In [2]:
import os
import io
import fitz
import base64
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from langchain_core.documents import Document
from langchain.chat_models import init_chat_model
from langchain.prompts import PromptTemplate
from langchain.schema.messages import HumanMessage
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from dotenv import load_dotenv

import warnings
warnings.filterwarnings("ignore")

In [3]:
load_dotenv()

True

In [4]:
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")

In [5]:
llm = init_chat_model("openai:gpt-4.1")

In [6]:
clip_model=CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor=CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model.eval()

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_norm2): LayerNorm((512,), eps=1e-05,

## Embedding Functions

In [12]:
def image_embedding(image_data):
    """ 
    Embedding Images Data using CLIP
    """
    if isinstance(image_data, str):  # If path
        image = Image.open(image_data).convert("RGB")
    else:
        image = image_data

    inputs = clip_processor(images=image, return_tensors="pt")

    with torch.no_grad():
        image_features = clip_model.get_image_features(**inputs)

        # Normalizing Embeddings to unit vector
        normalized_image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return normalized_image_features.squeeze().numpy()
    
def text_embedding(text_data):
    """ 
    Embedding Text Data using CLIP
    """
    inputs = clip_processor(
        text=text_data,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=77    # CLIP's max token length
    )

    with torch.no_grad():
        text_features = clip_model.get_text_features(**inputs)

        # Normalizing Embeddings to unit vector
        normalized_text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return normalized_text_features.squeeze().numpy()

## Processing PDF

In [13]:
pdf_path = "multimodal_sample.pdf"

document = fitz.open(filename=pdf_path)

all_documents = []
all_embeddings = []
image_data_store = {}

In [14]:
chunk_size=500
chunk_overlap=100

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=chunk_size, chunk_overlap=chunk_overlap
)

In [15]:
document

Document('multimodal_sample.pdf')

In [None]:
for i, page in enumerate(document):
    ## Process Text
    text = page.get_text()
    
    if text.strip():
        # Create temporary document for splitting
        temporary_document = Document(
            page_content=text, metadata={"page": i, "type":"text"}
        )
        text_chunks = text_splitter.split_documents(documents=[temporary_document])

        # Embed Each Chunk using CLIP
        for text_chunk in text_chunks:
            embedding = text_embedding(text_data=text_chunk.page_content)
            all_embeddings.append(embedding)
            all_documents.append(text_chunk)

    ## Process Images: Three Important Actions:
    ## 1. Convert PDF image to PIL format
    ## 2. Store as base64 for GPT-4V (which needs base64 images)
    ## 3. Create CLIP embedding for retrieval

    for image_index, image in enumerate(page.get_images(full=True)):
        try:
            xref = image[0]
            base_image = document.extract_image(xref=xref)
            image_bytes = base_image["image"]

            # Convert to PIL Image
            pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

            # Create Unique Identifier
            image_id = f"page_{i}_img_{image_index}"

            # Store Image as base64 for later use with Model
            buffered = io.BytesIO()
            pil_image.save(fp=buffered, format="PNG")
            image_base64 = base64.b64encode(buffered.getvalue()).decode()
            image_data_store[image_id] = image_base64

            # Embed Image using CLIP
            embedding = image_embedding(image_data=pil_image)
            all_embeddings.append(embedding)

            # Create Document for Image
            image_document = Document(
                page_content=f"[Image: {image_id}]",
                metadata={"page":i, "type":"image", "image_id":image_id}
            )
            all_documents.append(image_document)

        except Exception as e:
            print(f"Error processing image {image_index} on page {i}: {e}")
            continue

document.close()

In [17]:
all_documents

[Document(metadata={'page': 0, 'type': 'text'}, page_content='Annual Revenue Overview\nThis document summarizes the revenue trends across Q1, Q2, and Q3. As illustrated in the chart\nbelow, revenue grew steadily with the highest growth recorded in Q3.\nQ1 showed a moderate increase in revenue as new product lines were introduced. Q2 outperformed\nQ1 due to marketing campaigns. Q3 had exponential growth due to global expansion.'),
 Document(metadata={'page': 0, 'type': 'image', 'image_id': 'page_0_img_0'}, page_content='[Image: page_0_img_0]')]

## Create Unified Vector Store with CLIP Embedding

In [18]:
embeddings_array = np.array(all_embeddings)
embeddings_array

array([[-0.00267245,  0.01283001, -0.05183142, ..., -0.00385085,
         0.02977719, -0.00010682],
       [ 0.01732343, -0.01327688, -0.02427032, ...,  0.08994049,
        -0.00272151,  0.0325304 ]], shape=(2, 512), dtype=float32)

In [19]:
vector_store = FAISS.from_embeddings(
    text_embeddings=[(document.page_content, embedding) for document, embedding in zip(all_documents, embeddings_array)],
    embedding=None,  # We're using precomputed embeddings
    metadatas=[document.metadata for document in all_documents]
)

vector_store

`embedding_function` is expected to be an Embeddings object, support for passing in a function will soon be removed.


<langchain_community.vectorstores.faiss.FAISS at 0x15e8001b1d0>

## Create Retrieval Function

In [22]:
def retrieve_multimodal(query, k=5):
    """ 
    Unified Retrieval using CLIP Embeddings for both text and images
    """
    # Embed Query using CLIP
    query_embedding = text_embedding(text_data=query)

    # Search in Unified Vector Store
    results = vector_store.similarity_search_by_vector(
        embedding=query_embedding,
        k=k
    )

    return results

In [24]:
def create_multimodal_message(query, retrieved_documents):
    """ 
    Create a message with both text and images for Model
    """
    content = []
    
    # Add the query
    content.append({
        "type": "text",
        "text": f"Question: {query}\n\nContext:\n"
    })
    
    # Separate text and image documents
    text_documents = [document for document in retrieved_documents if document.metadata.get("type") == "text"]
    image_documents = [document for document in retrieved_documents if document.metadata.get("type") == "image"]
    
    # Add text context
    if text_documents:
        text_context = "\n\n".join([
            f"[Page {text_document.metadata['page']}]: {text_document.page_content}"
            for text_document in text_documents
        ])
        content.append({
            "type": "text",
            "text": f"Text excerpts:\n{text_context}\n"
        })
    
    # Add images
    for image_document in image_documents:
        image_id = image_document.metadata.get("image_id")
        if image_id and image_id in image_data_store:
            content.append({
                "type": "text",
                "text": f"\n[Image from page {image_document.metadata['page']}]:\n"
            })
            content.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/png;base64,{image_data_store[image_id]}"
                }
            })
    
    # Add instruction
    content.append({
        "type": "text",
        "text": "\n\nPlease answer the question based on the provided text and images."
    })
    
    return HumanMessage(content=content)

## Create Multimodal RAG Pipeline

In [25]:
def multimodal_pdf_rag_pipeline(query):
    """
    Main pipeline for multimodal RAG.
    """
    # Retrieve relevant documents
    retrieved_documents = retrieve_multimodal(query=query, k=5)
    
    # Create multimodal message
    message = create_multimodal_message(query=query, retrieved_documents=retrieved_documents)
    
    # Get response from LLM
    response = llm.invoke([message])
    
    # Print retrieved context info
    print(f"\nRetrieved {len(retrieved_documents)} documents:")
    for retrieved_document in retrieved_documents:
        document_type = retrieved_document.metadata.get("type", "unknown")
        page = retrieved_document.metadata.get("page", "?")
        if document_type == "text":
            preview = retrieved_document.page_content[:100] + "..." if len(retrieved_document.page_content) > 100 else retrieved_document.page_content
            print(f"  - Text from page {page}: {preview}")
        else:
            print(f"  - Image from page {page}")
    print("\n")
    
    return response.content

## Query

In [26]:
queries = [
    "What does the chart on page 1 show about revenue trends?",
    "Summarize the main findings from the document",
    "What visual elements are present in the document?"
]

In [27]:
for query in queries:
    print(f"\nQuery: {query}")
    print("-" * 50)

    answer = multimodal_pdf_rag_pipeline(query)
    
    print(f"Answer: {answer}")
    print("=" * 70)


Query: What does the chart on page 1 show about revenue trends?
--------------------------------------------------

Retrieved 2 documents:
  - Text from page 0: Annual Revenue Overview
This document summarizes the revenue trends across Q1, Q2, and Q3. As illust...
  - Image from page 0


Answer: The chart on page 1 shows that revenue increased steadily across Q1, Q2, and Q3. In Q1 (blue bar), revenue was the lowest but showed a moderate increase, likely due to new product lines. In Q2 (green bar), revenue grew further, driven by marketing campaigns. Q3 (red bar) had the highest revenue, indicating exponential growth attributed to global expansion. Overall, the chart demonstrates a consistent upward trend in revenue over the three quarters, with the largest jump occurring in Q3.

Query: Summarize the main findings from the document
--------------------------------------------------

Retrieved 2 documents:
  - Text from page 0: Annual Revenue Overview
This document summarizes the revenu