### Multimodel RAG

In [11]:
import fitz
from langchain_core.documents import Document
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
import numpy as np
from langchain.chat_models import init_chat_model
from langchain.prompts import PromptTemplate
from langchain.schema.messages import HumanMessage
from sklearn.metrics.pairwise import cosine_similarity
import os
import base64
import io
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS

In [12]:
from dotenv import load_dotenv
load_dotenv()
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")

## initialize Clip model for unified embeddings
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

clip_model.eval()  # Set the model to evaluation mode

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_norm2): LayerNorm((512,), eps=1e-05,

In [13]:
## Embedding functions
def embed_image(image_data):
    """Embed image using CLIP model."""
    if isinstance(image_data, str):
        # If image_data is a base64 string, decode it
        image = Image.open(image_data).convert("RGB")
    else: # if PIL Image object
        image = image_data


    inputs = clip_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        features = clip_model.get_image_features(**inputs)
        # Normalize the features
        features = features / features.norm(dim=-1, keepdim=True)
        return features.squeeze().cpu().numpy()
    
def embed_text(text):
    """Embed text using CLIP."""
    inputs = clip_processor(
        text=text, 
        return_tensors="pt", 
        padding=True,
        truncation=True,
        max_length=77  # CLIP's max token length
    )
    with torch.no_grad():
        features = clip_model.get_text_features(**inputs)
        # Normalize embeddings
        features = features / features.norm(dim=-1, keepdim=True)
        return features.squeeze().numpy()

In [14]:
pdf_path = "multimodal_sample.pdf"
doc= fitz.open(pdf_path)

# Storage for all documents and embeddings
all_docs = []
all_embeddings = []
image_data_store = {} # store actual image data for llm

# Text Splitter
splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=100)

In [15]:
doc

Document('multimodal_sample.pdf')

In [16]:
for i, page in enumerate(doc):
    # process text
    text = page.get_text()
    if text.strip():  # Only process non-empty pages
        temp_doc = Document(page_content=text, metadata={"page": i, "type": "text"})
        text_chunks = splitter.split_documents([temp_doc])

        for chunk in text_chunks:
            embedding = embed_text(chunk.page_content)
            all_embeddings.append(embedding)
            all_docs.append(chunk)

    # process images
    for img_index, img in enumerate(page.get_images(full=True)):
        try:
            xref = img[0]
            base_image = doc.extract_image(xref)
            image_bytes = base_image["image"]
            pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

            # create image identifier
            image_id = f"page_{i}_img_{img_index}"
            
            # Store the image data for later use
            buffered = io.BytesIO()
            pil_image.save(buffered, format="PNG")
            image_base64 = base64.b64encode(buffered.getvalue()).decode()
            image_data_store[image_id] = image_base64

            # Embed the image
            embedding = embed_image(pil_image)
            all_embeddings.append(embedding)
            
            # Create a Document for the image
            img_doc = Document(page_content=f"[Image {image_id}]", metadata={"page": i, "type": "image", "index": image_id})
            all_docs.append(img_doc)
        except Exception as e:
            print(f"Error processing image on page {i}, index {img_index}: {e}")

In [17]:
all_embeddings

[array([-2.67243758e-03,  1.28299380e-02, -5.18314578e-02,  4.14879806e-02,
        -2.33941991e-02, -7.55867921e-03, -3.67659107e-02,  1.19710788e-01,
         8.52081627e-02,  2.05414207e-03, -1.11533785e-02, -1.29592577e-02,
         5.25014475e-02, -3.65390349e-03,  4.76078279e-02,  1.58372652e-02,
         2.03387495e-02,  4.35361564e-02, -3.29173729e-03,  2.03181189e-02,
         1.88017765e-03, -4.23493832e-02,  5.44102024e-03,  3.70934680e-02,
        -1.65622756e-02,  6.48646755e-03, -4.78011742e-02,  8.67477432e-03,
         5.88859841e-02, -3.21394093e-02,  4.32439931e-02,  9.65298619e-03,
        -4.47922898e-03, -1.94856990e-02, -3.63503098e-02, -1.23472419e-02,
        -2.17928477e-02, -1.99016798e-02,  8.09620097e-02, -3.32987122e-02,
        -2.38900762e-02, -3.96138281e-02, -1.27279637e-02,  3.50380875e-02,
        -2.52217352e-02,  2.00032187e-03,  1.49660530e-02, -2.31977534e-02,
        -6.86791688e-02, -5.25773445e-04, -2.22545322e-02, -1.04104020e-02,
        -1.9

In [18]:
all_docs

[Document(metadata={'page': 0, 'type': 'text'}, page_content='Annual Revenue Overview\nThis document summarizes the revenue trends across Q1, Q2, and Q3. As illustrated in the chart\nbelow, revenue grew steadily with the highest growth recorded in Q3.\nQ1 showed a moderate increase in revenue as new product lines were introduced. Q2 outperformed\nQ1 due to marketing campaigns. Q3 had exponential growth due to global expansion.'),
 Document(metadata={'page': 0, 'type': 'image', 'index': 'page_0_img_0'}, page_content='[Image page_0_img_0]')]

In [20]:
# craete unified vector store with CLIP embeddings
embedding_array = np.array(all_embeddings)

# create custom FAISS index since we have precomputed embeddings
vector_store = FAISS.from_embeddings(
    text_embeddings=[(doc.page_content, emb) for doc, emb in zip(all_docs, embedding_array)],
    embedding=None,  # We already have embeddings
    metadatas= [doc.metadata for doc in all_docs]
)

`embedding_function` is expected to be an Embeddings object, support for passing in a function will soon be removed.


In [22]:
(all_docs, embedding_array)

([Document(metadata={'page': 0, 'type': 'text'}, page_content='Annual Revenue Overview\nThis document summarizes the revenue trends across Q1, Q2, and Q3. As illustrated in the chart\nbelow, revenue grew steadily with the highest growth recorded in Q3.\nQ1 showed a moderate increase in revenue as new product lines were introduced. Q2 outperformed\nQ1 due to marketing campaigns. Q3 had exponential growth due to global expansion.'),
  Document(metadata={'page': 0, 'type': 'image', 'index': 'page_0_img_0'}, page_content='[Image page_0_img_0]')],
 array([[-0.00267244,  0.01282994, -0.05183146, ..., -0.00385089,
          0.02977717, -0.0001069 ],
        [ 0.01732345, -0.01327708, -0.0242704 , ...,  0.08993968,
         -0.00272156,  0.03253062]], shape=(2, 512), dtype=float32))

In [23]:
vector_store

<langchain_community.vectorstores.faiss.FAISS at 0x163099790>

In [28]:
def retrieve_multimodal(query, k=5):
    """Retrieve documents based on a multimodal query."""
    # Embed the query
    query_embedding = embed_text(query)
    
    # Perform similarity search
    results = vector_store.similarity_search_by_vector(embedding=query_embedding, k=k)
    
    return results

In [32]:
def create_multimodal_message(query, retrieved_docs):
    """Create a message with both text and image content for GPT4-V"""

    content = []
    # Add the query
    content.append({"type": "text", "text": f"Question: {query}\n\nContext:\n"})

    # Seperate text and image content
    text_docs = [doc for doc in retrieved_docs if doc.metadata.get("type") == "text"]
    image_docs = [doc for doc in retrieved_docs if doc.metadata.get("type") == "image"]

    # Add text context
    if text_docs:
        text_context = "\n\n".join([
            f"[Page {doc.metadata['page']}]: {doc.page_content}" for doc in text_docs
        ])
        content.append({"type": "text", "text": f"Text excerpts:\n{text_context}\n\n"})

    # Add image context
    if image_docs:
        for doc in image_docs:
            image_id = doc.metadata["index"]
            if image_id and image_id in image_data_store:
                content.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{image_data_store[image_id]}"
                    }
                })
                content.append({"type": "text", "text": f"Image {image_id} from page {doc.metadata['page']}."}) 

    # Add instructions for the model
    content.append({
        "type": "text",
        "text": "\n\nPlease answer the question based on the provided context (text and images). If you cannot answer, say 'I don't know'."
    })

    return HumanMessage(
        content=content,
        additional_kwargs={"role": "user"}
    )

In [33]:
def multimodal_pdf_rag_pipeline(query):
    """Main pipeline for multimodal RAG with PDF documents."""
    # Step 1: Retrieve relevant documents
    retrieved_docs = retrieve_multimodal(query)

    # Step 2: Create a multimodal message for the model
    multimodal_message = create_multimodal_message(query, retrieved_docs)

    # Step 3: Initialize the chat model (GPT-4V)
    llm = init_chat_model("openai:gpt-4.1")

    # Step 4: Get the response from the model
    response = llm.invoke([multimodal_message])

    # Print retrieved context info
    print(f"Retrieved {len(retrieved_docs)} documents:")
    for doc in retrieved_docs:
        doc_type = doc.metadata.get("type", "unknown")
        page = doc.metadata.get("page", "N/A")
        if doc_type == "text":
            preview = doc.page_content[:100] + "..." if len(doc.page_content) > 100 else doc.page_content
            print(f"Text [Page {page}]: {preview}")
        else:
            image_id = doc.metadata.get("index", "unknown")
            print(f"Image [ID: {image_id}] from Page {page}")

    return response.content

In [34]:
if __name__ == "__main__":
    # Example queries
    queries = [
        "What does the chart on page 1 show about revenue trends?",
        "Summarize the main findings from the document",
        "What visual elements are present in the document?"
    ]
    
    for query in queries:
        print(f"\nQuery: {query}")
        print("-" * 50)
        answer = multimodal_pdf_rag_pipeline(query)
        print(f"Answer: {answer}")
        print("=" * 70)


Query: What does the chart on page 1 show about revenue trends?
--------------------------------------------------
Retrieved 2 documents:
Text [Page 0]: Annual Revenue Overview
This document summarizes the revenue trends across Q1, Q2, and Q3. As illust...
Image [ID: page_0_img_0] from Page 0
Answer: The chart on page 1 shows that revenue increased consistently across Q1, Q2, and Q3. Q1 had the lowest revenue, Q2 saw more growth, and Q3 experienced the highest and most significant increase in revenue. This matches the text, which explains a moderate rise in Q1, a boost in Q2 due to marketing, and exponential growth in Q3 due to global expansion.

Query: Summarize the main findings from the document
--------------------------------------------------
Retrieved 2 documents:
Text [Page 0]: Annual Revenue Overview
This document summarizes the revenue trends across Q1, Q2, and Q3. As illust...
Image [ID: page_0_img_0] from Page 0
Answer: The main findings from the document are:

- Revenue i