# Multimodal Retrieval Augmented Generation (RAG) with Gemini, Vertex AI Vector Search, and LangChain

---

# Part 2: User Query Process

## Getting Started

### Install Vertex AI SDK for Python and other dependencies

In [None]:
# Install or upgrade (-U), quietly (-q), the Vertex AI SDK, the LangChain packages and the
# document/image handling libraries listed below.
# NOTE(review): no versions are pinned, so a fresh install may resolve to different releases over time.
%pip install -U -q google-cloud-aiplatform langchain-core langchain-google-vertexai langchain-text-splitters langchain-community "unstructured[all-docs]" pypdf pydantic lxml pillow matplotlib opencv-python tiktoken

### Restart current runtime

To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which will restart the current kernel.

In [None]:
# Restart kernel after installs so that your environment can access the new packages
import IPython

# Fetch the running IPython application and ask its kernel to shut down;
# the True argument requests a restart instead of a plain shutdown.
app = IPython.Application.instance()
app.kernel.do_shutdown(True)

<div class="alert alert-block alert-warning">
<b>⚠️ The kernel is going to restart. Please wait until it is finished before continuing to the next step. ⚠️</b>
</div>


### Authenticate your notebook environment (Colab only)

If you are running this notebook on Google Colab, run the following cell to authenticate your environment. This step is not required if you are using [Vertex AI Workbench](https://cloud.google.com/vertex-ai-workbench).

In [None]:
import sys

# Additional authentication is required for Google Colab
# (Colab is detected by the `google.colab` module already being loaded in this process).
if "google.colab" in sys.modules:
    # Authenticate user to Google Cloud
    from google.colab import auth

    auth.authenticate_user()

### Define Google Cloud project information

In [None]:
# Project and region passed to `aiplatform.init` and to the vector store in later cells.
PROJECT_ID = "sublime-vine-445509-s8"  # @param {type:"string"}
LOCATION = "us-central1"  # @param {type:"string"}

# For Vector Search Staging
# The same bucket name is reused later for the document store and for signed URLs.
GCS_BUCKET = "hashem_yaazor"  # @param {type:"string"}
GCS_BUCKET_URI = f"gs://{GCS_BUCKET}"

### Initialize the Vertex AI SDK

In [None]:
from google.cloud import aiplatform

# Initialise the Vertex AI SDK with the project, region and staging bucket defined in the config cell.
aiplatform.init(project=PROJECT_ID, location=LOCATION, staging_bucket=GCS_BUCKET_URI)

### Import libraries

In [None]:
import base64
import os
import re

from IPython.display import Image, Markdown, display
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain_core.documents import Document
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_google_vertexai import (
    ChatVertexAI,
    VectorSearchVectorStore,
    VertexAIEmbeddings,
)
from langchain_google_vertexai.vectorstores.document_storage import GCSDocumentStorage
from google.cloud import storage

### Define model information

- [Vertex AI - Model Information](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models)

In [None]:
# Chat model name handed to ChatVertexAI when the RAG chain is built.
MODEL_NAME = "gemini-1.5-flash"
GEMINI_OUTPUT_TOKEN_LIMIT = 8192

# Embedding model name handed to VertexAIEmbeddings for the vector store.
EMBEDDING_MODEL_NAME = "text-embedding-004"
EMBEDDING_TOKEN_LIMIT = 2048

# The smaller of the two limits above; it is passed as `max_output_tokens` to the chat model.
# NOTE(review): this caps the generated answer at the embedding limit (2048) — confirm that is intended.
TOKEN_LIMIT = min(GEMINI_OUTPUT_TOKEN_LIMIT, EMBEDDING_TOKEN_LIMIT)

### Define index information

Connect to an existing Vertex AI Vector Search index:

*   Copy index id from https://console.cloud.google.com/vertex-ai/matching-engine/indexes
*   Copy endpoint id from https://console.cloud.google.com/vertex-ai/matching-engine/index-endpoints



In [None]:
# Identifiers of the Vector Search index and index endpoint this notebook connects to.
INDEX_ID = '4488078921232809984' #@param {type: "string"}
ENDPOINT_ID = '7764658756377378816' #@param {type: "string"}

## Create retriever

- Create [`VectorSearchVectorStore`](https://api.python.langchain.com/en/latest/vectorstores/langchain_google_vertexai.vectorstores.vectorstores.VectorSearchVectorStore.html) with Vector Search Index ID and Endpoint ID.
- Use the embedding model named in `EMBEDDING_MODEL_NAME` (`text-embedding-004`; see the [text embeddings reference](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings)) as the embedding model.

In [None]:
# The vectorstore to use to index the summaries
# (built from the index and endpoint identified by INDEX_ID and ENDPOINT_ID).
vectorstore = VectorSearchVectorStore.from_components(
    project_id=PROJECT_ID,
    region=LOCATION,
    gcs_bucket_name=GCS_BUCKET,
    index_id=INDEX_ID,
    endpoint_id=ENDPOINT_ID,
    embedding=VertexAIEmbeddings(model_name=EMBEDDING_MODEL_NAME),
    # NOTE(review): presumably has to match how the index was created (streaming vs. batch updates) — confirm.
    stream_update=True,
)

- Create Multi-Vector Retriever using the vector store you created.
- Since vector stores only contain the embedding and an ID, you'll also need to create a document store indexed by ID to get the original source documents after searching for embeddings.

In [None]:
# Resolve the bucket that backs the document store and fail fast if it is missing.
# Errors are deliberately NOT swallowed: the next cell needs a valid `bucket`, so
# continuing after a failure would only surface later as a confusing NameError or
# as storage errors far from the real cause.
storage_client = storage.Client()
bucket = storage_client.bucket(GCS_BUCKET)
if not bucket.exists():
    raise ValueError(f"Bucket '{GCS_BUCKET}' does not exist.")

In [None]:
# Document store backed by the GCS bucket resolved in the previous cell.
# NOTE(review): "RFM_chunks" is presumably the object prefix the chunks were stored under — confirm.
docstore = GCSDocumentStorage(bucket, "RFM_chunks")

# Key handed to the retriever so vector store hits can be matched to documents in the docstore.
id_key = "doc_id"
retriever_multi_vector_img = MultiVectorRetriever(
    vectorstore=vectorstore,
    docstore=docstore,
    id_key=id_key,
)

## Create Chain with Retriever and Gemini LLM

In [None]:
def looks_like_base64(sb):
    """Tell whether ``sb`` has the shape of a base64 string.

    The text must be a non-empty run of base64 alphabet characters
    (letters, digits, ``+``, ``/``) optionally ending in up to two ``=``.
    """
    shape_match = re.match("^[A-Za-z0-9+/]+[=]{0,2}$", sb)
    return shape_match is not None


def is_image_data(b64data):
    """
    Check if the base64 data is an image by looking at the start of the data.

    The payload is base64-decoded and its first 8 bytes are compared against
    the leading magic bytes of JPEG, PNG, GIF and RIFF (WebP) files.

    Args:
        b64data: base64-encoded payload (str or bytes).

    Returns:
        True if the decoded data starts with one of the known signatures,
        False otherwise, including when the input cannot be decoded.
    """
    image_signatures = (
        b"\xFF\xD8\xFF",  # jpg
        b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A",  # png
        b"\x47\x49\x46\x38",  # gif ("GIF8"); previously mistyped as b"\xa47..." and never matched
        b"\x52\x49\x46\x46",  # webp ("RIFF" container header)
    )
    try:
        header = base64.b64decode(b64data)[:8]  # Decode and get the first 8 bytes
    except Exception:
        # Undecodable input (bad padding, wrong type, ...) is simply "not an image".
        return False
    # bytes.startswith accepts a tuple and returns True if any prefix matches.
    return header.startswith(image_signatures)


def split_image_text_types(docs):
    """
    Partition retrieved items into base64 images and plain texts.

    Each item may be a Document (its page_content is used) or a raw string.
    An item counts as an image only when it both looks like base64 and
    decodes to a known image signature; everything else is treated as text.

    Returns:
        dict with keys "images" and "texts", each a list of strings in
        retrieval order.
    """
    images, texts = [], []
    for item in docs:
        # Unwrap Document objects so the checks below always see a string
        content = item.page_content if isinstance(item, Document) else item
        is_image = looks_like_base64(content) and is_image_data(content)
        target = images if is_image else texts
        target.append(content)
    return {"images": images, "texts": texts}


def img_prompt_func(data_dict):
    """
    Build the multimodal prompt from the user's question and the retrieved context.

    Args:
        data_dict: mapping with "question" (the user's query) and "context",
            itself a mapping with "texts" (list of strings) and "images"
            (list of base64-encoded image strings).

    Returns:
        A one-element list holding a HumanMessage whose content is a text
        part (instructions, question and joined texts) followed by one
        image_url part per image.
    """
    formatted_texts = "\n".join(data_dict["context"]["texts"])
    messages = [
        {
            "type": "text",
            "text": (
                """You are a learning assistant tasked with helping trainees in\
                 the pilot course understand the 'Ofer' helicopter systems and \
                 operating instructions. You will receive a mix of content, \
                 including text, tables, and images, often in the form of charts\
                  or graphs.
Provide clear, accurate, and professional answers to the user's questions about\
 the helicopter, using the information provided. Your responses should focus \
 exclusively on addressing the user's question in a concise and professional \
 manner, based on the materials you receive. Do not include any remarks about \
 the source of the materials, their format, or how they were compiled. Simply \
 deliver the most relevant and comprehensive answer to the question, ensuring \
 it aligns with the provided information."""
                # Blank line between the instructions and the question; without it the
                # prompt read "...provided information.User-provided question: ...".
                "\n\n"
                f"User-provided question: {data_dict['question']}\n\n"
                "Text and / or tables:\n"
                f"{formatted_texts}"
            ),
        }
    ]

    # Adding image(s) to the messages if present
    # NOTE(review): every image is labelled image/jpeg, although is_image_data also
    # accepts PNG/GIF/WebP payloads — confirm the model tolerates the mismatch.
    if data_dict["context"]["images"]:
        for image in data_dict["context"]["images"]:
            messages.append(
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{image}"},
                }
            )
    return [HumanMessage(content=messages)]

# Create RAG chain
# NOTE: the "context" branch does not call the retriever. It ignores its input and
# returns the global `source_docs`, looked up when the chain is invoked, so
# `source_docs` must already be defined by the time `chain_multimodal_rag.invoke(...)` runs.
chain_multimodal_rag = (
    {
        "context": RunnableLambda(lambda x: source_docs),
        "question": RunnablePassthrough(),
    }
    | RunnableLambda(img_prompt_func)
    | ChatVertexAI(
        temperature=0,
        model_name=MODEL_NAME,
        max_output_tokens=TOKEN_LIMIT,
    )  # Multi-modal LLM
    | StrOutputParser()
)

In [None]:
def sources_details(docs):
    """Group the page numbers of the given documents by source file.

    Args:
        docs: iterable of objects with a ``metadata`` mapping that contains
            'filename' and 'page_number'.

    Returns:
        dict mapping each filename to a list of page numbers (stored
        page_number + 1), in first-seen order and without duplicates, so
        several chunks from the same page are reported once.
    """
    docs_details = {}
    for doc in docs:
        # presumably the stored page_number is 0-based, hence the +1 — TODO confirm
        page = int(doc.metadata['page_number']) + 1
        pages = docs_details.setdefault(doc.metadata['filename'], [])
        if page not in pages:
            pages.append(page)
    return docs_details

def sources_details_string(docs):
    """Render the per-file page listing of ``docs`` as a human-readable string.

    The first line is a fixed header; each following line names one file and
    its page(s), using "page" for a single entry and "pages" otherwise.
    """
    lines = []
    for filename, page_numbers in sources_details(docs).items():
        label = "pages" if len(page_numbers) != 1 else "page"
        joined = ', '.join(str(number) for number in page_numbers)
        lines.append(f'The file "{filename}" {label} {joined}')

    return 'The answer is based on the following files:\n' + '\n'.join(lines)

## User query process

In [None]:
# User question sent to the retriever and the RAG chain.
# (The original literal opened with four quote characters, which left a stray `"` at the start of the query text.)
query = "what is palacards?"

In [None]:
# Number of vector store hits to request for the query.
RETRIEVAL_TOP_K = 10

# `invoke(query, limit=10)` did not limit anything: extra keyword arguments are not
# forwarded to MultiVectorRetriever's search. The number of hits is controlled by the
# retriever's `search_kwargs`, which are passed to the vector store search.
retriever_multi_vector_img.search_kwargs["k"] = RETRIEVAL_TOP_K

docs = retriever_multi_vector_img.invoke(query)
# Split the retrieved documents into base64 images and texts; the RAG chain reads
# this global as its context.
source_docs = split_image_text_types(docs)
source_docs

### Get Retrieved documents

### Get generative response

In [None]:
# Run the RAG chain on the query; its context is the `source_docs` global
# produced by the retrieval cell.
result = chain_multimodal_rag.invoke(query)

# Render the answer as Markdown (last expression of the cell is displayed).
Markdown(result)

In [None]:
# Point GOOGLE_APPLICATION_CREDENTIALS at a service-account key file.
# NOTE(review): presumably needed so that signed URLs can be generated further down — confirm.
# NOTE(review): the path is a hard-coded Colab location; the key file must be uploaded there
# beforehand and should never be committed alongside the notebook.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/content/sublime-vine-445509-s8-e9eb269d62b2.json"

In [None]:
from datetime import timedelta

def generate_signed_url(folder, filename):
    """Create a V4 signed GET URL, valid for one hour, for an object in GCS_BUCKET.

    Args:
        folder: object-name prefix ("folder") inside the bucket.
        filename: object name relative to ``folder``.

    Returns:
        The signed URL as returned by ``Blob.generate_signed_url``.
    """
    # GCS object names always use "/" as separator; os.path.join would use the
    # host OS separator (e.g. "\\" on Windows), so join with posixpath instead.
    import posixpath

    storage_client = storage.Client()
    bucket = storage_client.bucket(GCS_BUCKET)
    blob = bucket.blob(posixpath.join(folder, filename))

    url = blob.generate_signed_url(
        version="v4",
        expiration=timedelta(hours=1),
        method="GET",
    )

    return url

# Example: build and print a signed URL for one specific PDF in the bucket.
folder = "ofer_documents"
filename = "RFM עדכון 8.2.24.pdf"
url = generate_signed_url(folder, filename)
print(url)

In [None]:
def list_blobs_with_prefix(bucket_name, prefix):
    """Return the ``public_url`` of every blob in ``bucket_name`` whose name starts with ``prefix``."""
    client = storage.Client()
    matching_blobs = client.list_blobs(bucket_name, prefix=prefix)
    return [blob.public_url for blob in matching_blobs]

# List the public URLs of all objects stored under the "ofer_documents" prefix.
# NOTE(review): these are `public_url` values; presumably they are only reachable
# if the objects are publicly readable — confirm.
prefix = "ofer_documents"
urls = list_blobs_with_prefix(GCS_BUCKET, prefix)
print(urls)

In [None]:
# Text/table documents among the retrieved results. `texts_docs` used to be defined
# only in a later cell, so this cell failed with NameError on a fresh kernel.
# A document is an image only if it looks like base64 AND decodes to an image header,
# which is the same rule split_image_text_types applies.
texts_docs = [
    doc
    for doc in docs
    if not (looks_like_base64(doc.page_content) and is_image_data(doc.page_content))
]

files = sources_details(texts_docs)
files_string = sources_details_string(texts_docs)
print(files.keys())
for file in files.keys():
    print(f"File: {file}")
# Show the human-readable summary (it was computed but never displayed).
print(files_string)

In [None]:
# Bucket "folder" that holds the source PDFs referenced by the documents' metadata.
folder = "pdf_documents"

# Keep every document that is not an image. The previous filter
# (`not looks_like_base64(...) and not is_image_data(...)`) also dropped short text
# chunks that merely match the base64 pattern, disagreeing with
# split_image_text_types, which treats those as text.
texts_docs = [
    doc
    for doc in docs
    if not (looks_like_base64(doc.page_content) and is_image_data(doc.page_content))
]
files = sources_details(texts_docs)
# Print one signed download link per source file.
for file in files.keys():
    print(generate_signed_url(folder, file))

# Render the retrieved images inline.
for image_b64 in source_docs["images"]:
    display(Image(base64.b64decode(image_b64)))