# Multimodal Retrieval Augmented Generation (RAG) with Gemini, Vertex AI Vector Search, and LangChain


---

# Part 1: Data Extraction, Summarization, and Storage

## Getting Started

### Install Vertex AI SDK for Python and other dependencies

In [None]:
# Install/upgrade (-U) quietly (-q) the Vertex AI SDK, the LangChain packages,
# Unstructured with its "all-docs" extras, and the PDF/imaging helper libraries.
%pip install -U -q google-cloud-aiplatform langchain-core langchain-google-vertexai langchain-text-splitters langchain-community "unstructured[all-docs]" pypdf pydantic lxml pillow matplotlib opencv-python

### Restart current runtime

To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which will restart the current kernel.

In [None]:
# Restart kernel after installs so that your environment can access the new packages
import IPython

# Grab the running IPython application and ask its kernel to shut down.
# NOTE(review): the True argument is presumably the "restart" flag — confirm against the IPython kernel API.
app = IPython.Application.instance()
app.kernel.do_shutdown(True)

<div class="alert alert-block alert-warning">
<b>⚠️ The kernel is going to restart. Please wait until it is finished before continuing to the next step. ⚠️</b>
</div>




### Authenticate your notebook environment (Colab only)



If you are running this notebook on Google Colab, run the following cell to authenticate your environment. This step is not required if you are using [Vertex AI Workbench](https://cloud.google.com/vertex-ai-workbench).


In [None]:
import sys

# Additional authentication is required for Google Colab
# (Colab is detected by the "google.colab" module already being loaded).
if "google.colab" in sys.modules:
    # Authenticate user to Google Cloud
    # The import lives inside the guard so it is only attempted on Colab.
    from google.colab import auth

    auth.authenticate_user()

### Define Google Cloud project information

In [None]:
# Google Cloud project and region; both are passed to aiplatform.init and to the
# vector store created later in this notebook.
PROJECT_ID = "sublime-vine-445509-s8"  # @param {type:"string"}
LOCATION = "us-central1"  # @param {type:"string"}

# For Vector Search Staging
# Bucket name (without the gs:// scheme); also used for the document store.
GCS_BUCKET = "hashem_yaazor"  # @param {type:"string"}
# gs:// URI form of the same bucket, used as the SDK staging bucket.
GCS_BUCKET_URI = f"gs://{GCS_BUCKET}"

### Initialize the Vertex AI SDK

In [None]:
from google.cloud import aiplatform

# Set the SDK-wide defaults (project, region, staging bucket) from the constants defined above.
aiplatform.init(project=PROJECT_ID, location=LOCATION, staging_bucket=GCS_BUCKET_URI)

### Import libraries

In [None]:
!apt-get install -y poppler-utils # Install the poppler-utils package, which contains pdfinfo
!pip install unstructured[pdf] # Install unstructured with extra dependencies for PDF support
!apt install tesseract-ocr # Install the tesseract OCR engine
!pip install pytesseract  # Install the Python wrapper for tesseract
import nltk
nltk.download('all')
!cp -R ~/nltk_data/tokenizers/punkt/PY3 ~/nltk_data/tokenizers/punkt/PY3_tab

In [None]:
import base64
import os
import uuid

from google.cloud import storage
from IPython.display import Image
from langchain.prompts import PromptTemplate
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda
from langchain_google_vertexai import (
    ChatVertexAI,
    VectorSearchVectorStore,
    VertexAI,
    VertexAIEmbeddings,
)
from langchain_google_vertexai.vectorstores.document_storage import GCSDocumentStorage
from unstructured.partition.pdf import partition_pdf

### Define model information

- [Vertex AI - Model Information](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models)

In [None]:
# Generative model used for the text, table, and image summaries, and the cap on
# how many tokens it may produce.
MODEL_NAME = "gemini-1.5-flash"
GEMINI_OUTPUT_TOKEN_LIMIT = 8192

# Embedding model used by the vector store.
# NOTE(review): confirm EMBEDDING_TOKEN_LIMIT against the documented input-token
# limit of this embedding model.
EMBEDDING_MODEL_NAME = "text-embedding-004"
EMBEDDING_TOKEN_LIMIT = 4096

# Summaries are produced by the generative model and then embedded, so the
# effective cap is the smaller of the two limits.
TOKEN_LIMIT = min(GEMINI_OUTPUT_TOKEN_LIMIT, EMBEDDING_TOKEN_LIMIT)

## Data Loading

#### Get documents and images

In [None]:
# Install documents

## Partition PDF tables, text, and images

### Extract data

In [None]:
# Folder holding the input PDF: /content/ when running on Colab, data/ otherwise.
pdf_folder_path = "/content/" if "google.colab" in sys.modules else "data/"
# Resolve the file inside that folder so the cell works both on and off Colab.
# On Colab this yields the same /content/ path that was previously hard-coded.
pdf_file_name = os.path.join(pdf_folder_path, "RFM עדכון 8.2.24.pdf")

# Extract tables and chunked text from the PDF file.
raw_pdf_elements = partition_pdf(
    filename=pdf_file_name,
    strategy="hi_res",
    # Image extraction is switched off in this call.
    extract_images_in_pdf=False,
    extract_image_block_to_payload=False,
    infer_table_structure=True,
    # Chunk by title; the three values below are the character-count thresholds
    # handed to the chunker.
    chunking_strategy="by_title",
    max_characters=4000,
    new_after_n_chars=3800,
    combine_text_under_n_chars=2000,
    unique_element_ids=True
)

In [None]:
# Bucket the partitioned elements by kind, based on the class path that shows up
# in the string form of each element's type. Elements matching neither pattern
# are dropped.
tables_elements, texts_elements = [], []
for element in raw_pdf_elements:
    element_type = str(type(element))
    if "unstructured.documents.elements.Table" in element_type:
        tables_elements.append(element)
        continue
    if "unstructured.documents.elements.CompositeElement" in element_type:
        texts_elements.append(element)

# Plain-text content of each bucket, in the same order as the elements.
tables = [table_element.text for table_element in tables_elements]
texts = [text_element.text for text_element in texts_elements]

### Generate summaries

In [None]:
def generate_summaries(
    chunks: list[str], summarize: bool = True
) -> list[str]:
    """Return one retrieval-oriented summary per chunk.

    Args:
        chunks: Raw text or table strings to summarize.
        summarize: When True, each chunk is summarized by the generative model.
            When False, the chunks themselves are used as their summaries.

    Returns:
        A new list with one entry per chunk. A chunk whose model call fails is
        given the placeholder ``"Error processing document"``. An empty input
        yields an empty list.
    """
    # No model call is needed in these two cases, so return before building the
    # prompt, the model client, and the chain.
    if not chunks:
        return []
    if not summarize:
        # Return a copy so callers never end up sharing the input list.
        return list(chunks)

    prompt_text = """You are an assistant tasked with summarizing tables and text for retrieval. \
    These summaries will be embedded and used to retrieve the raw text or table elements. \
    Give a concise summary of the table or text that is well optimized for retrieval. Table or text: {element} """
    prompt = PromptTemplate.from_template(prompt_text)

    # Fallback that produces a recognizable placeholder if the model call fails.
    empty_response = RunnableLambda(
        lambda x: AIMessage(content="Error processing document")
    )
    model = VertexAI(
        temperature=0, model_name=MODEL_NAME, max_output_tokens=TOKEN_LIMIT
    ).with_fallbacks([empty_response])
    summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()

    # max_concurrency=1: the chunks are processed one at a time.
    return summarize_chain.batch(chunks, {"max_concurrency": 1})

# Summarize the text chunks and the table chunks separately. Downstream code
# relies on each result list being index-aligned with its input list.
text_summaries = generate_summaries(texts)
table_summaries = generate_summaries(tables)

In [None]:
import time

def encode_image(image_path: str) -> str:
    """Read the file at ``image_path`` and return its bytes as a base64 text string."""
    with open(image_path, mode="rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")


def image_summarize(base64_image: str) -> str:
    """Ask the chat model for a retrieval-oriented summary of a single image.

    Args:
        base64_image: Base64-encoded image data; it is sent to the model as a
            ``data:image/png`` URL.

    Returns:
        The ``content`` of the model's reply.
    """
    prompt = """You are an assistant tasked with summarizing images for retrieval. \
    These summaries will be embedded and used to retrieve the raw image. \
    Give a concise summary of the image that is well optimized for retrieval.
    If it's a table, extract all elements of the table.
    If it's a graph, explain the findings in the graph.
    Do not include any numbers that are not mentioned in the image.
    """
    # One multimodal message: the instructions followed by the inline image.
    instruction_part = {"type": "text", "text": prompt}
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/png;base64,{base64_image}"},
    }
    request = HumanMessage(content=[instruction_part, image_part])

    model = ChatVertexAI(model_name=MODEL_NAME, max_output_tokens=TOKEN_LIMIT)
    reply = model.invoke([request])
    return reply.content


def generate_img_summaries(path: str) -> tuple[list[str], list[str]]:
    """
    Encode and summarize every ``.png`` file found under a directory tree.

    path: Root directory to walk. Within each directory, files are visited in
        sorted name order; files without a ``.png`` suffix are ignored.

    Returns two index-aligned lists: the base64 strings of the images and the
    summaries produced for them. The function pauses 5 seconds after each
    summarized image.
    """
    encoded_images = []
    summaries = []

    for folder, _subdirs, filenames in os.walk(path):
        for filename in sorted(filenames):
            if not filename.endswith(".png"):
                continue
            encoded = encode_image(os.path.join(folder, filename))
            encoded_images.append(encoded)
            summaries.append(image_summarize(encoded))
            # Pause between consecutive model calls.
            time.sleep(5)

    return encoded_images, summaries


# Image summaries
# Encodes and summarizes every .png under /content/images; both lists are empty
# when that folder contains no PNG files.
img_base64_list, image_summaries = generate_img_summaries("/content/images")

In [None]:
def get_missing_summaries(
    texts: list[str], summaries: list[str], indexes: bool = False
) -> "list[str] | list[int]":
    """Find the chunks whose summary is missing or failed.

    A summary counts as missing when it is empty (or ``None``) or equal to the
    placeholder ``'Error processing document'``.

    Args:
        texts: Source chunks, index-aligned with ``summaries``.
        summaries: Current summaries, one per chunk.
        indexes: When True, return the positions of the missing summaries
            instead of the corresponding chunks.

    Returns:
        The positions (``indexes=True``) or the source chunks
        (``indexes=False``) of every missing summary, in original order.
    """
    missing_positions = [
        position
        for position, summary in enumerate(summaries)
        if not summary or summary == 'Error processing document'
    ]
    if indexes:
        return missing_positions
    return [texts[position] for position in missing_positions]

# Report how many text chunks still lack a usable summary, then show those chunks.
print(len(get_missing_summaries(texts, text_summaries)))
print(get_missing_summaries(texts, text_summaries))
# Same report for the table chunks.
print(len(get_missing_summaries(tables, table_summaries)))
print(get_missing_summaries(tables, table_summaries))
# NOTE(review): this prints the raw base64 payload of every image without a
# summary, which can produce very large output.
print(get_missing_summaries(img_base64_list, image_summaries))

In [None]:
# Number of passes the retry loop below makes over the missing summaries.
MAX_TRIES = 6 # @param {type: "integer"}

In [None]:
def resummarize_missing_summaries(
    chunks: list[str], current_summaries: list[str], N: int, is_image: bool = False
) -> list[str]:
    """Retry summarization for the chunks that still have no usable summary.

    Args:
        chunks: Source chunks (text, table text, or base64 images),
            index-aligned with ``current_summaries``.
        current_summaries: Summaries obtained so far.
        N: Zero-based number of the current attempt. On the last attempt
            (``N == MAX_TRIES - 1``) text/table chunks that are still missing
            fall back to their raw content instead of another model call.
        is_image: When True, ``chunks`` are base64 images and are re-sent to
            the image summarizer on every attempt.

    Returns:
        A new list with the missing entries replaced by the new results;
        ``current_summaries`` itself is left untouched.
    """
    # Work on a copy so the caller's list is never modified through an alias.
    summaries = list(current_summaries)

    # Scan once for the missing positions and derive the chunks from them.
    missing_indexes = get_missing_summaries(chunks, current_summaries, indexes=True)
    if not missing_indexes:
        # Nothing to retry: skip the summarizers entirely.
        return summaries

    chunks_to_summarize = [chunks[index] for index in missing_indexes]

    if is_image:
        new_summaries = [image_summarize(image) for image in chunks_to_summarize]
    else:
        # Last attempt: use the raw chunk as its own summary.
        is_last_attempt = N == MAX_TRIES - 1
        new_summaries = generate_summaries(
            chunks_to_summarize, summarize=not is_last_attempt
        )

    for index, summary in zip(missing_indexes, new_summaries):
        summaries[index] = summary

    return summaries


# Retry pass: each iteration re-requests only the summaries that are still
# missing. On the final iteration (N == MAX_TRIES - 1) text and table chunks that
# still lack a summary fall back to their raw content; images are always re-sent
# to the image summarizer.
for i in range(MAX_TRIES):
    text_summaries = resummarize_missing_summaries(texts, text_summaries, N=i)
    table_summaries = resummarize_missing_summaries(tables, table_summaries, N=i)
    image_summaries = resummarize_missing_summaries(img_base64_list, image_summaries, N=i, is_image=True)

## Create & Deploy Vertex AI Vector Search Index & Endpoint

Skip this step if you already have Vector Search set up.

- https://console.cloud.google.com/vertex-ai/matching-engine/indexes

- Create [`MatchingEngineIndex`](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.MatchingEngineIndex)
  - https://cloud.google.com/vertex-ai/docs/vector-search/create-manage-index

In [None]:
# https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings
# Vector size of the index. It has to equal the output dimension of the embedding
# model used by the vector store (EMBEDDING_MODEL_NAME).
DIMENSIONS = 768  # NOTE(review): confirm this matches the configured embedding model's output size

# Create a Tree-AH Vector Search index.
index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
    display_name="mm_rag_langchain_index",
    dimensions=DIMENSIONS,
    approximate_neighbors_count=150,
    leaf_node_embedding_count=500,
    leaf_nodes_to_search_percent=7,
    description="Multimodal RAG LangChain Index",
    # Streaming updates: the vector store below is created with stream_update=True.
    index_update_method="STREAM_UPDATE",
)

- Create [`MatchingEngineIndexEndpoint`](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.MatchingEngineIndexEndpoint)
  - https://cloud.google.com/vertex-ai/docs/vector-search/deploy-index-public

In [None]:
# NOTE(review): despite its name, this value is only used as the endpoint's
# display name here; the deploy step passes a different literal as deployed_index_id.
DEPLOYED_INDEX_ID = "mm_rag_langchain_index_endpoint"

# Create the index endpoint with its public endpoint enabled.
index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
    display_name=DEPLOYED_INDEX_ID,
    description="Multimodal RAG LangChain Index Endpoint",
    public_endpoint_enabled=True,
)

- Deploy Index to Index Endpoint
  - NOTE: This will take a while to run.
  - You can stop this cell after starting it instead of waiting for deployment.
  - You can check the status at https://console.cloud.google.com/vertex-ai/matching-engine/indexes

In [None]:
# Deploy the index onto the endpoint; the value returned by deploy_index
# replaces index_endpoint.
index_endpoint = index_endpoint.deploy_index(
    index=index, deployed_index_id="mm_rag_langchain_deployed_index"
)
# Last expression of the cell: shows the endpoint's deployed indexes in the notebook output.
index_endpoint.deployed_indexes

## Create retriever & load documents

- Create [`VectorSearchVectorStore`](https://python.langchain.com/api_reference/google_vertexai/vectorstores/langchain_google_vertexai.vectorstores.vectorstores.VectorSearchVectorStore.html) with Vector Search Index ID and Endpoint ID.
- Use the embedding model configured in `EMBEDDING_MODEL_NAME` (`text-embedding-004`; see the [text embeddings reference](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings)) as the embedding model.

In [None]:
# The vectorstore to use to index the summaries
# It is identified through index.name and index_endpoint.name (the index and
# endpoint created above) and embeds text with the model in EMBEDDING_MODEL_NAME.
vectorstore = VectorSearchVectorStore.from_components(
    project_id=PROJECT_ID,
    region=LOCATION,
    gcs_bucket_name=GCS_BUCKET,
    index_id=index.name,
    endpoint_id=index_endpoint.name,
    embedding=VertexAIEmbeddings(model_name=EMBEDDING_MODEL_NAME),
    # The index above was created with index_update_method="STREAM_UPDATE".
    stream_update=True,
)

- Create Multi-Vector Retriever using the vector store you created.
- Since vector stores only contain the embedding and an ID, you'll also need to create a document store indexed by ID to get the original source documents after searching for embeddings.

In [None]:
# Verify that the bucket is reachable before anything is written to it.
# The error is printed for readability and then re-raised: the following cells
# need a valid `bucket`, so continuing after a failure would only defer the
# problem (or hit an undefined `bucket`).
try:
    storage_client = storage.Client()
    bucket = storage_client.bucket(GCS_BUCKET)
    if not bucket.exists():
        raise ValueError(f"Bucket '{GCS_BUCKET}' does not exist.")
except Exception as e:
    print(f"An error occurred: {e}")
    raise

In [None]:
# Document store for the original chunks, backed by the GCS bucket.
# NOTE(review): the second argument ("RFM_chunks") is presumably the object-name
# prefix inside the bucket — confirm against GCSDocumentStorage.
docstore = GCSDocumentStorage(bucket, "RFM_chunks")

# Metadata key that links a summary in the vector store to its source chunk in
# the document store.
id_key = "doc_id"
retriever_multi_vector_img = MultiVectorRetriever(
    vectorstore=vectorstore,
    docstore=docstore,
    id_key=id_key,
)

- Load data into Document Store and Vector Store

In [None]:
# Wrap every source chunk in a Document tagged with its id under id_key:
# text chunks first, then tables, then images.
chunks_documents = [
    Document(
        page_content=pdf_element.text,
        metadata={**pdf_element.metadata.to_dict(), id_key: pdf_element.id},
    )
    for pdf_element in [*texts_elements, *tables_elements]
]
# Images carry no element id, so each one gets a freshly generated UUID.
chunks_documents.extend(
    Document(page_content=encoded_image, metadata={id_key: str(uuid.uuid4())})
    for encoded_image in img_base64_list
)

doc_ids = [chunk_document.metadata[id_key] for chunk_document in chunks_documents]

# The summaries are concatenated in the same text/table/image order, so the
# summary at a given position is tagged with the id at that position.
summary_docs = [
    Document(page_content=summary_text, metadata={id_key: doc_ids[position]})
    for position, summary_text in enumerate(
        [*text_summaries, *table_summaries, *image_summaries]
    )
]

# Store the original chunks in the document store, keyed by their ids.
retriever_multi_vector_img.docstore.mset(
    [
        (chunk_document.metadata[id_key], chunk_document)
        for chunk_document in chunks_documents
    ]
)

In [None]:
def batch(iterable, batch_size=1000):
    """Yield consecutive slices of ``iterable`` with at most ``batch_size`` items each.

    ``iterable`` must support ``len()`` and slicing. The final slice may be
    shorter than ``batch_size``; an empty input yields nothing.
    """
    total = len(iterable)
    for offset in range(0, total, batch_size):
        stop = offset + batch_size
        yield iterable[offset:stop]

# Split to batches with a max size of 1,000
batch_size = 1000
# Materialize all batches of summary documents up front.
batches = list(batch(summary_docs, batch_size))

# Add the summary documents to the vector store one batch at a time.
for batch_docs in batches:
    retriever_multi_vector_img.vectorstore.add_documents(batch_docs)
