# Multi-Modal RAG

Before reading this notebook, make sure you have read the first notebook, `text_rag.ipynb`.

### What is Multi-modal RAG?

Multi-modal Retrieval-Augmented Generation (RAG) extends the standard RAG approach by incorporating **multiple types of data**—such as text, images, or even audio—**into the retrieval and generation process**. Instead of working with a single modality (like just text), multi-modal RAG systems can query and generate content based on various forms of input, allowing for richer and more diverse responses. 

For example, when dealing with a document that contains both images and text, a multi-modal RAG system can retrieve relevant images along with the associated text, enhancing the quality and relevance of the generated response.



![Multi-Modal RAG Image](../data/multi_modal_rag.jpg)



### Approaches for Working with Text and Images

Multi-modal RAG systems differ in how they handle text and images: the database and the language model (LLM) can either work with text and images together, or each can focus on a single modality. Many different architectures achieve this goal.

Another approach, not shown above, is to treat the file as a series of images:

- **File-as-Images** → **Image DB** → **Retrieve Images** → **Multi-modal LLM** → **Text answer and sources**

In this case, the document is converted into a series of images (e.g., scanned pages), stored in an image-specific database, and sent as images to a multi-modal LLM, which generates textual answers from the content of those images. The database used is usually specialized for this task.

In this exercise, you will learn how to implement a **Multimodal Retrieval-Augmented Generation (RAG)** pipeline from scratch, without relying on tools like `langchain`. Here, **text and images are stored in two separate vector stores**.

The different components of the pipeline are:

- **Text and image extraction from PDFs** – Extract raw text and images from PDF files to make the content processable.  
- **Text and image chunking** – Break the extracted text into smaller, meaningful segments and treat each image as its own chunk to improve retrieval.  
- **Embedding of the chunks (text and images)** – Convert text and image chunks into numerical representations (embeddings) using pre-trained models.  
- **Storage of the embeddings in a vector store** – Save both text and image embeddings in a specialized database (vector store) to enable fast similarity searches.  
- **Relevant chunks retrieval** – Query the vector store to find the most relevant text and image chunks based on user input.  
- **Setting and prompting of the LLM for a RAG** – Structure prompts and configure the language model to integrate retrieved text and image information into its responses.  
- **Additional tools for improved retrieval** – Use techniques like query expansion, which reformulates user queries for better recall, and reciprocal rank fusion, which merges the rankings obtained for multiple queries into a single ranking.  
- **Final multimodal RAG pipeline implementation** – Integrate all components into a complete system that retrieves relevant information (both text and images) and generates enhanced responses using the language model.

**Note:** To complete this exercise, you need an OpenAI API key, the PDF files with images, and the necessary libraries installed (see `requirements.txt`).

In [None]:
!pip install -r requirements.txt

In [None]:
import io
import os
import getpass
import json
from tqdm import tqdm

import numpy as np

import base64
import matplotlib.pyplot as plt
from PIL import Image

from src.data_classes import Chunk, DataType, Roles
from src.data_processing import PDFExtractorAPI, SimpleChunker
from src.embedding import (
    OpenAITextEmbeddings,
    VLM2VecImageEmbeddings,
    VLM2VecTextEmbeddings,
    compute_openai_large_embedding_cost,
)
from src.vectorstore import (
    ChromaDBVectorStore,
    VectorStoreRetriever,
)
from src.llm import OpenAILLM
from src.rag import Generator, DefaultRAG, query_expansion

In [None]:
data_folder = "../data"

pdf_files = [
    "Explainable_machine_learning_prediction_of_edema_a.pdf",
    "Modeling tumor size dynamics based on real‐world electronic health records.pdf",
]
example_pdf_file = "Explainable_machine_learning_prediction_of_edema_a.pdf"
example_pdf_path = os.path.join(data_folder, example_pdf_file)

text_vector_store_collection = "text_collection"
image_vector_store_collection = "image_collection"

text_vector_store_full_collection = "text_collection_full"
image_vector_store_full_collection = "image_collection_full"

In [None]:
os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API key: ")  # prompt without echoing the key

# Example

The example uses only `Explainable_machine_learning_prediction_of_edema_a.pdf`. Please have a quick look at it before starting the exercise.

In [None]:
test_question = "According to SHAP analysis, which factors were the most influential in predicting higher-grade edema (Grade 2+)?"

## PDF Text and Image Extraction  

The first step in the pipeline is to extract text and images from the document.  

In this exercise, we use the `MinerU` library, which relies under the hood on several models, including `doclayout_yolo` for layout segmentation. Note that this model's license is not commercially permissive.

Extracting images can be challenging, as **irrelevant images** (such as logos) are often included, and some images may be **split into multiple images**. It may also be helpful to link the position of images to nearby text for more accurate retrieval. Specialized tools or methods might be required to efficiently handle images embedded in the document.
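As an illustration of the first problem, here is a minimal sketch (a hypothetical helper, not part of `PDFExtractorAPI`) that drops exact duplicate images, such as a logo repeated on every page, by hashing their decoded bytes, and also drops very small files. It assumes each extracted image is a dict with an `"image_base64"` key, as in the extraction output used below; the `min_bytes` threshold is an arbitrary assumption, and file size is only a crude proxy for relevance.

```python
import base64
import hashlib


def filter_images(images, min_bytes=5_000):
    """Drop duplicate and very small images from an extraction result.

    `images` is assumed to be a list of dicts holding a base64 string under
    the "image_base64" key. `min_bytes` is a hypothetical threshold: tiny
    decoded files are often logos, icons, or decorative elements.
    """
    seen_hashes = set()
    kept = []
    for image in images:
        raw = base64.b64decode(image["image_base64"])
        if len(raw) < min_bytes:
            continue  # likely a logo, icon, or decorative element
        digest = hashlib.sha256(raw).hexdigest()
        if digest in seen_hashes:
            continue  # exact duplicate, e.g. a logo repeated on every page
        seen_hashes.add(digest)
        kept.append(image)
    return kept
```

It could be applied right after extraction, e.g. `images = filter_images(images)`, but the threshold should be checked against the actual figures in the PDFs so that small but relevant plots are not discarded.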

In [None]:
data_extractor = PDFExtractorAPI()
_, text, images = data_extractor.extract_text_and_images(example_pdf_path)

In [None]:
print(text[:1000])

In [None]:
img_data = base64.b64decode(images[2]["image_base64"])
img = Image.open(io.BytesIO(img_data))

plt.imshow(img)
plt.axis("off")
plt.show()

In [None]:
img_data = base64.b64decode(images[0]["image_base64"])
img = Image.open(io.BytesIO(img_data))

plt.imshow(img)
plt.axis("off")
plt.show()

## Chunking

The second step is to split the extracted text into smaller chunks, which will later be embedded and retrieved efficiently. 

In this exercise, we use a simple heuristic approach: the text is split iteratively, first by heading levels (`#`), then by line breaks (`\n`), and finally by sentence (`.`). A piece is split further only if it exceeds a predefined maximum length.
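The splitting heuristic can be sketched as a short recursive function. This is an illustrative re-implementation assuming the separators are tried in the order described above; it is not the code of `SimpleChunker`, and the regular expressions are simplifications (for instance, the sentence split after any period also splits abbreviations such as "e.g.").

```python
import re

# Separators tried in order: markdown headings, line breaks, sentence ends.
SEPARATORS = [
    r"\n(?=#)",     # the line break before a line starting with '#'
    r"\n",          # any line break
    r"(?<=\.)\s+",  # whitespace following a period
]


def recursive_split(text, max_chunk_size=1000, level=0):
    """Split `text` into pieces no longer than `max_chunk_size` when possible."""
    text = text.strip()
    if not text:
        return []
    # Small enough, or no finer separator left: keep the piece as one chunk.
    if len(text) <= max_chunk_size or level >= len(SEPARATORS):
        return [text]
    chunks = []
    for piece in re.split(SEPARATORS[level], text):
        # Only pieces that are still too long get split by the next separator.
        chunks.extend(recursive_split(piece, max_chunk_size, level + 1))
    return chunks
```

Pieces that are still too long after the sentence split are kept whole; a production chunker would typically also merge very short neighbouring pieces.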

**Images are treated as separate chunks** with their own `DataType` and are stored in a separate list. Additional relevant metadata can also be attached to them, such as the image's position relative to the text or its caption, if available.
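One way to obtain such context is sketched below. It assumes, hypothetically, that the extracted text is markdown in which each image appears as a reference like `![caption](images/fig1.jpg)`; MinerU typically produces markdown of this kind, but whether `PDFExtractorAPI` keeps these references, and how they map to the entries of `images`, is not shown in this notebook. The helper collects a window of text on each side of every reference so it can be stored in the image chunk's metadata.

```python
import re

# Matches markdown image references such as ![caption](images/fig1.jpg).
IMAGE_REF = re.compile(r"!\[([^\]]*)\]\(([^)]+)\)")


def image_contexts(markdown, window=300):
    """Collect the text surrounding each image reference in `markdown`.

    Returns one dict per image with its path, its alt text (often empty), and
    up to `window` characters of text before and after the reference.
    """
    contexts = []
    for match in IMAGE_REF.finditer(markdown):
        start, end = match.span()
        contexts.append({
            "path": match.group(2),
            "alt_text": match.group(1),
            "text_before": markdown[max(0, start - window):start].strip(),
            "text_after": markdown[end:end + window].strip(),
        })
    return contexts
```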

Each chunk is enriched with metadata, including:  
- **Source file** – The document from which the chunk originates.  
- **Chunk counter** – The position of the chunk within the file.  
- **Unique identifier (`chunk_id`)** – Ensures each chunk can be referenced independently.  
- **Data type** – The chunk type (text or image).

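```python
from enum import Enum
from typing import Optional

from pydantic import BaseModel, Field


class DataType(str, Enum):
    TEXT = "text"
    IMAGE = "image"


class Chunk(BaseModel):
    chunk_id: int
    content: str  # raw text, or a base64-encoded image for image chunks
    metadata: dict = Field(default_factory=dict)  # e.g. {"source_text": ...}
    data_type: Optional[DataType] = None
    score: Optional[float] = None
```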


In [None]:
chunker = SimpleChunker()
text_chunks = chunker.chunk_text(text, {"source_text": example_pdf_file})
image_chunks = chunker.chunk_images(images, {"source_text": example_pdf_file})

In [None]:
print(len(text_chunks))
text_chunks[0]

In [None]:
print(len(image_chunks))

img_data = base64.b64decode(image_chunks[2].content)
img = Image.open(io.BytesIO(img_data))

plt.imshow(img)
plt.axis("off")
plt.show()

## Embedding Models  

Once the text and images are divided into chunks, each chunk is converted into a numerical representation (embedding) that captures its meaning.  

For text, we use OpenAI’s `text-embedding-3-large`.

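For images, we use `VLM2Vec`, which embeds both images and text into a shared space: a text query embedded with `VLM2VecTextEmbeddings` can therefore be compared directly with the image embeddings. As with text embeddings, various options exist for image embeddings, each with its own trade-offs.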
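The shared space is what makes text-to-image retrieval possible. The minimal numpy sketch below computes cosine similarity between one query vector and a matrix of embeddings. Cosine similarity is an assumption here (the metric actually used depends on how the vector store collection is configured), and the toy vectors in the test stand in for real model outputs.

```python
import numpy as np


def cosine_similarity(query, matrix):
    """Cosine similarity between one query vector and each row of `matrix`."""
    query = np.asarray(query, dtype=float)
    matrix = np.asarray(matrix, dtype=float)
    # L2-normalize so that the dot product equals the cosine of the angle.
    query = query / np.linalg.norm(query)
    matrix = matrix / np.linalg.norm(matrix, axis=1, keepdims=True)
    return matrix @ query
```

With real data, the query vector would come from the VLM2Vec text model and the matrix would be `image_embeddings`. Scoring an OpenAI text-embedding vector against VLM2Vec image embeddings would be meaningless, because the two models define different spaces.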

In [None]:
_ = compute_openai_large_embedding_cost(text_chunks, verbose=True)

text_embedding_model = OpenAITextEmbeddings()
text_embeddings = text_embedding_model.get_embedding(
    [chunk.content for chunk in text_chunks]
)

print(text_embeddings.shape)
text_embeddings[0]

In [None]:
image_embeddings = []

image_embedding_model = VLM2VecImageEmbeddings()
for chunk in tqdm(image_chunks):
    image_embeddings.append(image_embedding_model.get_embedding(chunk.content))


image_embeddings = np.array(image_embeddings)

In [None]:
# Also define the text embedding for the image-text embedding model
image_text_embedding_model = VLM2VecTextEmbeddings()

## Vector Store and Retrieval  

Once the chunks are embedded, they must be stored in a way that allows efficient retrieval. In this exercise, we use `ChromaDB`.  

Text and image embeddings are stored separately, requiring a distinct `top_k` value for each during retrieval. Since the models used for text and image embeddings differ, their similarities cannot be directly compared. Additionally, while sparse search is not available for images, metadata filtering can still be applied.
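The sketch below illustrates retrieval with a separate `top_k` per modality plus an optional metadata filter. It is a brute-force stand-in for what a vector store does, not `VectorStoreRetriever`'s implementation; the function name, the cosine scoring, and the exact-match `where` filter are assumptions.

```python
import numpy as np


def top_k_per_modality(query_vec, embeddings, metadatas, top_k, where=None):
    """Return (index, score) pairs for the `top_k` most similar rows.

    `where` is an optional dict of metadata key/value pairs that a row must
    match exactly. Scores are cosine similarities and are only comparable
    within one modality, i.e. within one embedding model.
    """
    embeddings = np.asarray(embeddings, dtype=float)
    q = np.asarray(query_vec, dtype=float)
    scores = embeddings @ q / (np.linalg.norm(embeddings, axis=1) * np.linalg.norm(q))
    # Keep only the rows whose metadata passes the filter.
    candidates = [
        i for i, meta in enumerate(metadatas)
        if where is None or all(meta.get(k) == v for k, v in where.items())
    ]
    ranked = sorted(candidates, key=lambda i: scores[i], reverse=True)
    return [(i, float(scores[i])) for i in ranked[:top_k]]
```

It would be called once per modality (the text query embedding against the text store with `top_k_text`, the VLM2Vec query embedding against the image store with `top_k_image`), and the two result lists kept apart rather than merged by score.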

In [None]:
vector_store_text = ChromaDBVectorStore(text_vector_store_collection)
vector_store_text.insert_documents(text_chunks, text_embeddings)

In [None]:
vector_store_image = ChromaDBVectorStore(image_vector_store_collection)
vector_store_image.insert_documents(image_chunks, image_embeddings)

In [None]:
retriever = VectorStoreRetriever(
    text_embedding_model,
    vector_store_text,
    image_text_embedding_model,
    vector_store_image,
)

results = retriever.retrieve(test_question, top_k_text=10, top_k_image=5)

In [None]:
for result_l in results:
    for result in result_l:
        if result["chunk"].data_type == DataType.TEXT:
            print(result)
        elif result["chunk"].data_type == DataType.IMAGE:
            print(f"Chunk ID: {result['chunk_id']} | Score: {result['score']}")
            img_data = base64.b64decode(result["chunk"].content)
            img = Image.open(io.BytesIO(img_data))
            plt.imshow(img)
            plt.axis("off")
            plt.show()

## LLM  

The LLM is the core of the RAG system, responsible for generating responses based on the retrieved information. In this case, a **multi-modal LLM is required**; we use `gpt-4o-mini`.

This LLM expects input in the form of a list of messages, where each message includes the content and the role of the speaker (e.g., system, user, assistant).  

Images can be provided to this LLM as `base64`-encoded data URLs, but only in messages whose role is `user`.

Here is how messages are defined in this project:

```python
from enum import Enum
from typing import Optional

from pydantic import BaseModel


class Roles(str, Enum):
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"


class LLMMessage(BaseModel):
    content: Optional[str] = None
    role: Optional[Roles] = None
```
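Since images are only accepted in `user` messages, a small helper can assemble such a message from a question and a list of base64 strings. This sketch (hypothetical name `build_user_message`) produces the same OpenAI-style structure as the hand-written message in the `llm.generate` call further down. The default `image/jpeg` MIME type mirrors that call and is an assumption about the extracted images' format. The plain string `"user"` compares equal to `Roles.USER` because `Roles` subclasses `str`.

```python
def build_user_message(question, images_base64, mime_type="image/jpeg"):
    """Build one OpenAI-style user message holding text plus inline images.

    Each image is passed as a base64 data URL; `mime_type` is assumed to
    match the actual encoding of the images.
    """
    content = [{"type": "text", "text": question}]
    for b64 in images_base64:
        content.append(
            {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{b64}"}}
        )
    return {"role": "user", "content": content}
```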

In [None]:
llm = OpenAILLM(temperature=0.5)

In [None]:
img_data = base64.b64decode(image_chunks[2].content)
img = Image.open(io.BytesIO(img_data))

plt.imshow(img)
plt.axis("off")
plt.show()

In [None]:
answer, cost = llm.generate(
    [
        {
            "role": Roles.USER,
            "content": [
                {"type": "text", "text": test_question},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{image_chunks[2].content}"
                    },
                },
            ],
        },
    ],
    verbose=True,
)

In [None]:
print(answer.content)

## Generator  

Once the LLM is set up, a specific prompt needs to be defined for the RAG system. This prompt must include the retrieved chunks as context. The prompt has to be adapted to each specific project.
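How the retrieved chunks end up in the prompt is project-specific. Below is one possible sketch, not `Generator`'s actual implementation: text chunks are numbered and placed in the `{context}` slot so the model can cite them in `document_used`, while image chunks, which cannot be inlined into a text template, are attached as separate image parts labelled with the same document number. `SimpleChunk` is a hypothetical stand-in for the notebook's `Chunk` class, and `RAG_TEMPLATE` is a shortened version of the template defined below.

```python
from dataclasses import dataclass, field

RAG_TEMPLATE = "Here are the relevant DOCUMENTS:\n{context}\n\nHere is the USER QUESTION:\n{query}\n"


@dataclass
class SimpleChunk:
    """Minimal stand-in for the notebook's `Chunk` (hypothetical)."""
    content: str
    data_type: str  # "text" or "image"
    metadata: dict = field(default_factory=dict)


def build_rag_message(query, chunks, template=RAG_TEMPLATE):
    """Return one user message: numbered text context plus labelled images."""
    context_lines = []
    image_parts = []
    for i, chunk in enumerate(chunks):
        if chunk.data_type == "text":
            context_lines.append(f"Document [{i}]: {chunk.content}")
        else:
            # Images cannot go into the text template: reference them by number
            # there, and attach them as image parts carrying the same number.
            context_lines.append(f"Document [{i}]: (image, attached)")
            image_parts.append({"type": "text", "text": f"Document [{i}] (image):"})
            image_parts.append({
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{chunk.content}"},
            })
    prompt = template.format(context="\n".join(context_lines), query=query)
    return {"role": "user", "content": [{"type": "text", "text": prompt}] + image_parts}
```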

In [None]:
default_system_prompt = """You are a helpful assistant, and your task is to answer questions using relevant documents and images. Please first think step-by-step by mentioning which documents you used and then answer the question. Organize your output in a json formatted as dict{"step_by_step_thinking": Str(explanation), "document_used": List(integers), "answer": Str(answer)}. Your responses will be read by someone without specialized knowledge, so please have a definite and concise answer."""
print(default_system_prompt)

In [None]:
default_rag_template = """
Here are the relevant DOCUMENTS:
{context}

--------------------------------------------

Here is the USER QUESTION:
{query}

--------------------------------------------

Please think step-by-step and generate your output in json:
"""
print(default_rag_template)
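The prompts above ask the model for a JSON object with `step_by_step_thinking`, `document_used`, and `answer` keys, but models do not always return bare JSON. The `answer` printed later with `json.dumps` appears to be parsed already by the pipeline, so the hypothetical `parse_json_answer` below only illustrates one way to make that step robust: it unwraps a markdown code fence and, failing that, falls back to the outermost braces.

```python
import json
import re

FENCE = "`" * 3  # a markdown code fence, built this way to keep the sketch readable
FENCED_JSON = re.compile(rf"^{FENCE}(?:json)?\s*(.*?)\s*{FENCE}$", re.DOTALL)


def parse_json_answer(raw):
    """Parse the model's JSON answer, tolerating fences or surrounding prose."""
    text = raw.strip()
    # Models sometimes wrap JSON in a markdown code fence; unwrap it if present.
    fenced = FENCED_JSON.match(text)
    if fenced:
        text = fenced.group(1)
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        # Fall back to the outermost {...} span, e.g. when prose surrounds it.
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end <= start:
            raise
        return json.loads(text[start : end + 1])
```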

In [None]:
generator = Generator(llm, default_system_prompt, default_rag_template)

In [None]:
answer, cost = generator.generate(
    history=[],
    query=test_question,
    chunks=[
        Chunk(
            chunk_id=0,
            data_type=DataType.IMAGE,
            content=image_chunks[2].content,
            metadata={},
        ),
        Chunk(
            chunk_id=1,
            data_type=DataType.TEXT,
            content=text_chunks[0].content,
            metadata={},
        ),
    ],
    verbose=True,
)

In [None]:
print(answer.content)

## RAG Tools  

There are several methods to improve the retrieval quality of a RAG pipeline.

In this notebook, we implement **query expansion** to improve retrieval and apply **reciprocal rank fusion** to merge the chunk rankings obtained for multiple queries into a single ranking.
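Reciprocal rank fusion only looks at ranks, not raw scores: each chunk receives `1 / (k + rank)` from every ranked list it appears in (ranks starting at 1), and the contributions are summed. The sketch below is a minimal version of this standard technique; `k = 60` is the value commonly used in the literature, and the function name is hypothetical rather than this project's API.

```python
from collections import defaultdict


def reciprocal_rank_fusion(rankings, k=60):
    """Fuse several ranked lists of chunk ids into one (id, score) list.

    Each chunk scores sum(1 / (k + rank)) over the lists it appears in,
    with ranks starting at 1.
    """
    scores = defaultdict(float)
    for ranking in rankings:
        for rank, chunk_id in enumerate(ranking, start=1):
            scores[chunk_id] += 1.0 / (k + rank)
    # Highest fused score first; ties broken by chunk id for determinism.
    return sorted(scores.items(), key=lambda item: (-item[1], item[0]))
```

Because it ignores raw scores, RRF can fuse the results of the original and expanded queries without any score calibration. Since text and image scores are not comparable anyway, a natural choice is to fuse each modality's lists separately, keeping their `top_k` budgets independent.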

In [None]:
query_expansion_system_message = {
    "role": "system",
    "content": "You are a focused assistant designed to generate multiple, relevant search queries based solely on a single input query. Your task is to produce a list of these queries in English, without adding any further explanations or information.",
}

query_expansion_template_query = """
        Generate multiple search queries related to: {query}, and translate them into English if they are not already in English. Only output {expansion_number} queries in English.
        OUTPUT ({expansion_number} queries):
    """

In [None]:
answer, cost = query_expansion(
    test_question,
    llm,
    query_expansion_system_message,
    query_expansion_template_query,
    expansion_number=5,
)

answer
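Depending on how `query_expansion` formats its result, the raw LLM output may still need cleaning before each query is sent to the retriever. The hypothetical parser below strips list markers and quotes, drops blanks and duplicates, keeps the original query first, and caps the number of expansions. It is a sketch, not the behaviour of the `query_expansion` function imported at the top of the notebook.

```python
import re

# Leading list markers such as "1.", "2)", "-" or "*".
LIST_MARKER = re.compile(r"^\s*(?:\d+[.)]|[-*])\s*")


def parse_expanded_queries(raw, original_query, expansion_number):
    """Turn raw expansion output into a clean list of queries.

    The original query is always kept first, followed by at most
    `expansion_number` distinct, non-empty expanded queries.
    """
    queries = [original_query]
    for line in raw.splitlines():
        # Remove numbering/bullets, then surrounding whitespace and quotes.
        cleaned = LIST_MARKER.sub("", line).strip().strip('"')
        if cleaned and cleaned not in queries:
            queries.append(cleaned)
    return queries[: expansion_number + 1]
```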

## RAG  

Finally, the RAG pipeline is defined by integrating all the previously discussed components into a unified process.
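At a high level, the flow can be sketched as a function that wires the steps together. This is an illustrative orchestration, not `DefaultRAG`'s implementation: each step is passed in as a callable (all names hypothetical), and the keyword arguments mirror the configuration keys `top_k_text`, `top_k_image`, and `number_query_expansion` used in the cell below.

```python
def run_rag(query, expand, retrieve_text, retrieve_images, fuse, generate,
            top_k_text=5, top_k_image=3, number_query_expansion=0):
    """Hypothetical orchestration of the RAG steps, each given as a callable.

    expand(query, n) -> list of extra queries
    retrieve_text(q, k) / retrieve_images(q, k) -> ranked list of chunk ids
    fuse(list_of_rankings) -> single ranked list of chunk ids
    generate(query, sources) -> answer
    """
    queries = [query]
    if number_query_expansion > 0:
        queries += expand(query, number_query_expansion)
    # Retrieve and fuse each modality separately: their scores are not comparable.
    text_ids = fuse([retrieve_text(q, top_k_text) for q in queries])[:top_k_text]
    image_ids = fuse([retrieve_images(q, top_k_image) for q in queries])[:top_k_image]
    sources = text_ids + image_ids
    return generate(query, sources), sources
```

Passing the steps as callables keeps the orchestration testable with toy functions; with `number_query_expansion=0`, as in the configuration below, only the original query is searched.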

In [None]:
rag = DefaultRAG(
    llm,
    text_embedding_model,
    vector_store_text,
    generator,
    query_expansion_system_message,
    query_expansion_template_query,
    {"top_k_text": 5, "top_k_image": 3, "number_query_expansion": 0},
    image_text_embedding_model,
    vector_store_image,
)

In [None]:
print(test_question)

In [None]:
answer, sources, cost = rag.execute(test_question, {}, verbose=True)

In [None]:
print(json.dumps(answer, indent=3))

In [None]:
# The documents retrieved by the retriever:
print(len(sources))
print(sources[0])

In [None]:
print(cost)

# Exercises

The different blocks are redefined below, and a new pipeline is created that uses both PDFs.

1. Quickly go through the code and the notebook above to make sure you understand how each block works, focusing on how the images are handled in the pipeline.
2. Formulate a question about another plot in `Explainable_machine_learning_prediction_of_edema_a.pdf` that can only be answered using the plot, not the text. Analyze the answer and verify that it uses the image, then try the same question without providing the images to the RAG.
3. Do the same for `Modeling tumor size dynamics based on real‐world electronic health records.pdf`, and verify that the retrieved images indeed belong to it.
4. Discuss how the pipeline could be improved to achieve better answers, and identify its current pain points. How would it differ with a different multi-modal RAG architecture? If time permits, implement those changes.

In [None]:
data_extractor = PDFExtractorAPI()
chunker = SimpleChunker(max_chunk_size=1000)


text_chunks = []
image_chunks = []

for pdf_file in pdf_files:
    print(pdf_file)
    pdf_path = os.path.join(data_folder, pdf_file)
    _, text, images = data_extractor.extract_text_and_images(pdf_path)
    text_chunks_curr = chunker.chunk_text(text, {"source_text": pdf_file})
    image_chunks_curr = chunker.chunk_images(images, {"source_text": pdf_file})
    text_chunks.extend(text_chunks_curr)
    image_chunks.extend(image_chunks_curr)

print(len(text_chunks))
print(len(image_chunks))

In [None]:
text_embedding_model = OpenAITextEmbeddings()
text_embeddings = text_embedding_model.get_embedding(
    [chunk.content for chunk in text_chunks]
)
print(text_embeddings.shape)

In [None]:
image_embeddings = []

image_embedding_model = VLM2VecImageEmbeddings()
for chunk in tqdm(image_chunks):
    image_embeddings.append(image_embedding_model.get_embedding(chunk.content))

image_embeddings = np.array(image_embeddings)

image_text_embedding_model = VLM2VecTextEmbeddings()

In [None]:
vector_store_text = ChromaDBVectorStore(text_vector_store_full_collection)
vector_store_text.insert_documents(text_chunks, text_embeddings)

vector_store_image = ChromaDBVectorStore(image_vector_store_full_collection)
vector_store_image.insert_documents(image_chunks, image_embeddings)

In [None]:
retriever = VectorStoreRetriever(
    text_embedding_model,
    vector_store_text,
    image_text_embedding_model,
    vector_store_image,
)

results = retriever.retrieve(test_question, top_k_text=10, top_k_image=5)

In [None]:
llm = OpenAILLM(temperature=0.3)

In [None]:
system_prompt = """You are a helpful assistant, and your task is to answer questions using relevant documents and images. Please first think step-by-step by mentioning which documents you used and then answer the question. Organize your output in a json formatted as dict{"step_by_step_thinking": Str(explanation), "document_used": List(integers), "answer": Str(answer)}. Your responses will be read by someone without specialized knowledge, so please have a definite and concise answer."""
print(system_prompt)

In [None]:
rag_template = """
Here are the relevant DOCUMENTS:
{context}

--------------------------------------------

Here is the USER QUESTION:
{query}

--------------------------------------------

Please think step-by-step and generate your output in json:
"""
print(rag_template)

In [None]:
query_expansion_system_message = {
    "role": "system",
    "content": "You are a focused assistant designed to generate multiple, relevant search queries based solely on a single input query. Your task is to produce a list of these queries in English, without adding any further explanations or information.",
}

query_expansion_template_query = """
        Generate multiple search queries related to: {query}, and translate them into English if they are not already in English. Only output {expansion_number} queries in English.
        OUTPUT ({expansion_number} queries):
    """

In [None]:
generator = Generator(llm, system_prompt, rag_template)

In [None]:
rag = DefaultRAG(
    llm,
    text_embedding_model,
    vector_store_text,
    generator,
    query_expansion_system_message,
    query_expansion_template_query,
    {"top_k_text": 5, "top_k_image": 3, "number_query_expansion": 0},
    image_text_embedding_model,
    vector_store_image,
)

In [None]:
answer, sources, cost = rag.execute(
    "Here goes my amazing question!",
    {},
    verbose=True,
)

In [None]:
# The documents retrieved by the retriever:
print(len(sources))
print(sources[0])

In [None]:
print(json.dumps(answer, indent=3))

In [None]:
print(cost)

----------------