# Multi-Modal RAG Hands-On Theory Notebook

## What is RAG?

![RAG Image](images/rag.png)

Retrieval-Augmented Generation (RAG) is an AI technique that combines information retrieval with text generation. Instead of relying solely on a pre-trained language model’s internal knowledge, RAG dynamically retrieves relevant documents from an external knowledge base before generating a response.

## Why RAG?

![Why RAG Image](images/why_rag.png)

1. **Improved Accuracy:** RAG enhances the factual correctness of generated responses by retrieving up-to-date and domain-specific information, reducing the likelihood of hallucinations (fabricated information).

2. **Better Generalization:** Since RAG dynamically retrieves relevant documents, it performs well across various domains without requiring extensive fine-tuning, making it more adaptable to new topics.

3. **Reduced Model Size Requirements:** Instead of embedding all knowledge within a large model, RAG leverages external databases, allowing for smaller, more efficient models while maintaining high-quality responses.

4. **Enhanced Explainability:** By referencing retrieved documents, RAG provides verifiable sources for its answers, making it more transparent and easier to trust compared to purely generative models.

5. **And more...**

## What is Multi-modal RAG?

Multi-modal Retrieval-Augmented Generation (RAG) extends the standard RAG approach by incorporating **multiple types of data**—such as text, images, or even audio—**into the retrieval and generation process**. Instead of working with a single modality (like just text), multi-modal RAG systems can query and generate content based on various forms of input, allowing for richer and more diverse responses. 

For example, when dealing with a document that contains both images and text, a multi-modal RAG system can retrieve relevant images along with the associated text, enhancing the quality and relevance of the generated response.


### Text+Images RAG

Multi-modal RAG systems differ in how they handle text and images, depending on whether the database and the language model (LLM) work with both text and images or focus on a single modality. There are many different ways to achieve this.

![Multi-Modal RAG Image](images/multi_modal_rag.jpg)

Another approach not shown above is to consider the file as a series of images:
- **File-as-Images** → **Image DB** → **Retrieve Images** → **Multi-modal LLM** → **Text answer and sources**  
- In this case, the document is converted into a series of images (e.g., scanned pages), stored in an image-specific database, and sent as images to a multi-modal LLM, which generates textual answers based on the content of the images. The DB used is usually specialized for this task.

# Hands-on Example

In this exercise, you will learn how to implement a **Multimodal Retrieval-Augmented Generation (RAG)** pipeline from scratch, without relying on tools like `langchain`.  While `langchain` is a powerful framework that simplifies the development of RAG pipelines, it can sometimes lack flexibility for custom implementations, as it abstracts many components.

Here, two different vector stores are used so that **images and text are stored separately**.

The different components of the pipeline are:

- **Text and image extraction from PDFs** – Extract raw text and images from PDF files to make the content processable.  
- **Text and image chunking** – Break the extracted text and images into smaller, meaningful segments to improve retrieval efficiency.  
- **Embedding of the chunks (text and images)** – Convert text and image chunks into numerical representations (embeddings) using pre-trained models.  
- **Storage of the embeddings in a vector store** – Save both text and image embeddings in a specialized database (vector store) to enable fast similarity searches.  
- **Relevant chunks retrieval** – Query the vector store to find the most relevant text and image chunks based on user input.  
- **Setting and prompting of the LLM for a RAG** – Structure prompts and configure the language model to integrate retrieved text and image information into its responses.  
- **Additional tools for improved retrieval** – Techniques such as query expansion, which reformulates user queries for better recall, and reciprocal rank fusion, which combines results from multiple retrieval methods. These are mentioned but not implemented in this notebook.  
- **Final multimodal RAG pipeline implementation** – Integrate all components into a complete system that retrieves relevant information (both text and images) and generates enhanced responses using the language model.

**Note:** To complete this exercise, you need an OpenAI API key, the PDF files with images, and the necessary libraries installed (see `requirements.txt`).

The example is applied to `Explainable_machine_learning_prediction_of_edema_a.pdf`. Please have a quick look at it before starting the exercise.

We will try to answer the following question:

In [None]:
test_question = "According to SHAP analysis, which factors were the most influential in predicting higher-grade edema (Grade 2+)?"

## Setup

In [None]:
import sys

sys.path.append("../../")

In [None]:
!pip install -r ../../requirements.txt

In [None]:
import io
import os
import getpass
import json
from tqdm import tqdm

import numpy as np

import base64
import matplotlib.pyplot as plt
from PIL import Image

from helpers.constants_and_data_classes import Chunk, DataType
from helpers.data_processing import PDFExtractor, SimpleChunker
from helpers.embedding import (
    OpenAITextEmbeddings,
    OpenAITextEmbeddingsAzure,
    ImageEmbeddings,
    ImageEmbeddingsForText,
    compute_openai_large_embedding_cost,
)
from helpers.vectorstore import (
    ChromaDBVectorStore,
    VectorStoreRetriever,
)
from helpers.llm import OpenAILLM, OpenAILLMAzure
from helpers.rag import Generator, DefaultRAG

In [None]:
data_folder = "../../data"

example_pdf_file = "Explainable_machine_learning_prediction_of_edema_a.pdf"
example_pdf_path = os.path.join(data_folder, example_pdf_file)

text_vector_store_collection = "text_collection"
image_vector_store_collection = "image_collection"

text_vector_store_full_collection = "text_collection_full"
image_vector_store_full_collection = "image_collection_full"

In [None]:
os.environ["OPENAI_API_KEY"] = getpass.getpass()

# If you use an Azure endpoint, OPENAI_API_KEY is not needed; set the following instead:
# os.environ["AZURE_API_KEY"] = ""
# os.environ["AZURE_API_BASE"] = ""
# os.environ["AZURE_API_VERSION"] = ""

## LLM  

The LLM is the core of the RAG system, responsible for generating responses based on the retrieved information. There are many options available on-premise or online, each with different performance, speed, specialized knowledge and cost trade-offs.

In this case, a **multi-modal LLM is required**; we use `gpt-4o-mini`.  

This LLM expects input in the form of a list of messages, where each message includes the content and the role of the speaker (e.g., developer, user, assistant).  

Images can be provided to this LLM as `base64`, but only when the role is set to `user`.
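As an illustration, the sketch below builds a single user message that combines a question with an inline base64 image, using the content-parts layout of the OpenAI Chat Completions API (a `"text"` part plus an `"image_url"` part holding a `data:` URL). The function name `build_image_message` is hypothetical, and whether the notebook's `OpenAILLM.generate` accepts this structure as-is is an assumption: its `LLMMessage` model declares `content` as a plain string, so the helpers may assemble image parts internally.

```python
import base64


def build_image_message(question: str, image_bytes: bytes, mime_type: str = "image/png") -> dict:
    """Build a user message that mixes text and an inline base64-encoded image."""
    b64 = base64.b64encode(image_bytes).decode("ascii")
    return {
        "role": "user",  # image parts are only accepted in user messages
        "content": [
            {"type": "text", "text": question},
            {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{b64}"}},
        ],
    }


# Dummy bytes stand in for a real PNG file read from disk
msg = build_image_message("What does this figure show?", b"\x89PNG fake bytes")
print(msg["content"][1]["image_url"]["url"][:22])  # data:image/png;base64,
```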

Here is how messages are defined in the notebook's helper code:

```python
from enum import Enum
from typing import Optional

from pydantic import BaseModel


class Roles(str, Enum):
    DEVELOPER = "developer"  # previously called "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"


class LLMMessage(BaseModel):
    content: Optional[str] = None
    role: Optional[Roles] = None
```

In [None]:
# Check if both Azure environment variables exist
azure_endpoint = os.getenv("AZURE_API_BASE")
azure_api_key = os.getenv("AZURE_API_KEY")
if azure_endpoint and azure_api_key:
    llm = OpenAILLMAzure(temperature=0.5)
    print("Using AzureOpenAI client")
else:
    llm = OpenAILLM(temperature=0.5)
    print("Using OpenAI client")

In [None]:
print(test_question)

In [None]:
answer, price = llm.generate([{"role": "user", "content": test_question}], verbose=True)

In [None]:
print(answer.content)

## PDF Text and Images Extraction  

The first step in the pipeline is to extract text and images from the document.  

In this exercise, we use the `MinerU` library, which under the hood relies on several models, including `doclayout_yolo` for layout segmentation. Note that this model's license is not commercially permissive.

The choice of extraction tool should be considered carefully. Depending on the document type and formatting, different methods may be required to preserve text integrity and to exploit structural elements such as headings, tables, or metadata: for example, `pdfplumber` (better for tables) or `Tesseract OCR` (for scanned PDFs).

Extracting images can be challenging, as **irrelevant images** (such as logos) are often included, and some images may be **split into multiple images**. It may also be helpful to link the position of images to nearby text for more accurate retrieval. Specialized tools or methods might be required to efficiently handle images embedded in the document.
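The two heuristics mentioned above, discarding irrelevant images and linking images to nearby text, can be sketched as follows. This is a minimal, hypothetical example: `ImageBox`, `TextBlock`, the `min_area` threshold, and the assumption that y-coordinates grow downward the page (as in most PDF tools) are all illustrative and are not part of `PDFExtractor` or MinerU.

```python
from dataclasses import dataclass
from typing import List, Optional


# Hypothetical layout records; real extractors expose their own structures.
@dataclass
class ImageBox:
    page: int
    y: float  # top edge; y is assumed to grow downward the page
    width: float
    height: float


@dataclass
class TextBlock:
    page: int
    y: float  # top edge of the text block
    text: str


def drop_small_images(images: List[ImageBox], min_area: float = 10_000.0) -> List[ImageBox]:
    """Heuristic filter: tiny images (logos, icons) are rarely useful for retrieval."""
    return [img for img in images if img.width * img.height >= min_area]


def nearest_text(image: ImageBox, blocks: List[TextBlock]) -> Optional[str]:
    """Return the same-page text block closest to the image's bottom edge,
    since captions often sit just below figures."""
    same_page = [b for b in blocks if b.page == image.page]
    if not same_page:
        return None
    bottom = image.y + image.height
    return min(same_page, key=lambda b: abs(b.y - bottom)).text


# Made-up layout for illustration
images = [ImageBox(1, 100, 20, 20), ImageBox(2, 300, 400, 300)]
blocks = [
    TextBlock(2, 50, "Introduction ..."),
    TextBlock(2, 620, "Figure 2: SHAP summary plot."),
    TextBlock(1, 90, "Journal header"),
]
kept = drop_small_images(images)      # the 20x20 "logo" is dropped
print(nearest_text(kept[0], blocks))  # Figure 2: SHAP summary plot.
```

Treating the closest block below an image as its caption is only a heuristic; real captions can also sit above or beside figures.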

In [None]:
data_extractor = PDFExtractor()
_, text, images = data_extractor.extract_text_and_images(example_pdf_path)

In [None]:
print(text[:1000])

In [None]:
img_data = base64.b64decode(images[2]["image_base64"])
img = Image.open(io.BytesIO(img_data))

plt.imshow(img)
plt.axis("off")
plt.show()

In [None]:
img_data = base64.b64decode(images[0]["image_base64"])
img = Image.open(io.BytesIO(img_data))

plt.imshow(img)
plt.axis("off")
plt.show()

## Chunking

The second step is to split the extracted text into smaller chunks, which will later be embedded and retrieved efficiently. 

In this exercise, we use a simple heuristic approach: the text is split iteratively—first by heading levels (`#`), then by line breaks (`\n`), and finally by sentence (`.`). Splitting only occurs if the resulting chunk exceeds a predefined length. However, more advanced techniques exist, such as **semantic chunking** (which splits based on meaning rather than syntax) or **agentic chunking** (which dynamically adapts chunk sizes based on context).
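One plausible reading of this heuristic is sketched below: each separator level is tried in turn, small pieces are greedily packed together up to `max_len` characters, and only pieces that are still too long are split at the next, finer level. The regexes, the packing step, and the `max_len` default are assumptions for illustration; `SimpleChunker` may behave differently.

```python
import re
from typing import List

# Separator levels, coarsest first. Each regex matches an empty position,
# so re.split keeps every character: before a markdown heading line,
# after a line break, and after a sentence-ending period.
LEVELS = [r"(?m)(?=^#)", r"(?<=\n)", r"(?<=\.)(?=\s)"]


def chunk_text(text: str, max_len: int = 500, level: int = 0) -> List[str]:
    if len(text) <= max_len:
        return [text] if text.strip() else []
    if level == len(LEVELS):
        # No separator left: fall back to a hard cut every max_len characters
        return [text[i:i + max_len] for i in range(0, len(text), max_len)]

    pieces = [p for p in re.split(LEVELS[level], text) if p]
    chunks: List[str] = []
    current = ""
    for piece in pieces:
        if len(piece) > max_len:
            # Flush the buffer, then split the oversized piece at the next level
            if current.strip():
                chunks.append(current)
            current = ""
            chunks.extend(chunk_text(piece, max_len, level + 1))
        elif len(current) + len(piece) <= max_len:
            current += piece  # greedily pack small neighbouring pieces
        else:
            if current.strip():
                chunks.append(current)
            current = piece
    if current.strip():
        chunks.append(current)
    return chunks


doc = "# Intro\nShort intro.\n# Methods\n" + "A sentence here. " * 20
for c in chunk_text(doc, max_len=100):
    print(repr(c))
```

The greedy packing step matters: without it, splitting by line breaks or sentences would produce many tiny chunks that carry too little context to be useful for retrieval.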

**Images are treated as separate chunks** with a different `DataType`. Additional relevant metadata can also be included, such as the image's position relative to the text or its caption, if available. Image chunks are stored in a separate list.

Each chunk is enriched with metadata, including:  
- **Source file** – The document from which the chunk originates.  
- **Chunk counter** – The position of the chunk within the file.  
- **Unique identifier (`chunk_id`)** – Ensures each chunk can be referenced independently.  
- **Data type** – The chunk type (image or text).

Additional metadata could be included to enable more refined filtering and retrieval strategies.  

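```python
from enum import Enum
from typing import Optional

from pydantic import BaseModel, Field


class DataType(str, Enum):
    TEXT = "text"
    IMAGE = "image"


class Chunk(BaseModel):
    chunk_id: int
    content: str  # raw text, or base64-encoded image data for image chunks
    metadata: dict = Field(default_factory=dict)
    data_type: Optional[DataType] = None
    score: Optional[float] = None  # presumably set by the retriever
```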
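The sketch below shows how such metadata might be attached and then used for filtering, for example to restrict retrieval to a single source file. It uses a plain `dataclass` as a stand-in for the `Chunk` model above so that it runs without `pydantic`; the `chunk_counter` key and the helper names `make_chunks` and `filter_chunks` are hypothetical, while `source_text` is the key used by the notebook.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class SimpleChunk:
    """Stand-in for the notebook's pydantic Chunk model."""
    chunk_id: int
    content: str
    data_type: str
    metadata: dict = field(default_factory=dict)


def make_chunks(pieces: List[str], data_type: str, base_metadata: dict, start_id: int = 0) -> List[SimpleChunk]:
    """Wrap raw pieces as chunks with a unique id, a position counter and the shared source metadata."""
    return [
        SimpleChunk(
            chunk_id=start_id + i,
            content=piece,
            data_type=data_type,
            metadata={**base_metadata, "chunk_counter": i},
        )
        for i, piece in enumerate(pieces)
    ]


def filter_chunks(chunks: List[SimpleChunk], **conditions) -> List[SimpleChunk]:
    """Keep only chunks whose metadata matches every key=value condition."""
    return [c for c in chunks if all(c.metadata.get(k) == v for k, v in conditions.items())]


a = make_chunks(["intro", "methods"], "text", {"source_text": "paper_a.pdf"})
b = make_chunks(["results"], "text", {"source_text": "paper_b.pdf"}, start_id=len(a))
only_b = filter_chunks(a + b, source_text="paper_b.pdf")
print([c.chunk_id for c in only_b])  # [2]
```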


In [None]:
chunker = SimpleChunker()
text_chunks = chunker.chunk_text(text, {"source_text": example_pdf_file})
image_chunks = chunker.chunk_images(images, {"source_text": example_pdf_file})

In [None]:
print(len(text_chunks))
text_chunks[0]

In [None]:
print(len(image_chunks))

img_data = base64.b64decode(image_chunks[2].content)
img = Image.open(io.BytesIO(img_data))

plt.imshow(img)
plt.axis("off")
plt.show()

## Embedding Models  

Once the text and images are divided into chunks, each chunk is converted into a numerical representation (embedding) that captures its meaning.  

For text, we use OpenAI’s `text-embedding-3-large`, but other options exist, each with different trade-offs in deployment (on-premise vs. online), accuracy, speed, and cost. The choice of model depends on the specific needs of the retrieval task.

For images, we use `VLM2Vec`. As with text embeddings, various options exist for image embeddings, each with its own trade-offs. Two embedding models are needed here: one converts images into vectors, and another converts the user's text query into the same vector space so that the two can be compared.
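Vector stores typically rank chunks with a similarity measure such as cosine similarity. The toy example below shows the idea and why a shared vector space matters: a query vector can only be scored against image vectors if both come from compatible models, which is the role of `ImageEmbeddingsForText` here. The vectors are made up, and the distance ChromaDB is configured to use in this notebook is not stated, so cosine is an assumption.

```python
import math
from typing import List


def cosine(a: List[float], b: List[float]) -> float:
    """Cosine similarity: dot product divided by the product of the norms."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def rank(query_vec: List[float], item_vecs: List[List[float]]) -> List[int]:
    """Indices of item_vecs sorted from most to least similar to query_vec."""
    scores = [cosine(query_vec, v) for v in item_vecs]
    return sorted(range(len(item_vecs)), key=lambda i: scores[i], reverse=True)


# Made-up 3-d vectors; real embeddings have hundreds or thousands of dimensions.
image_vecs = [[0.9, 0.1, 0.0], [0.0, 1.0, 0.2], [0.1, 0.0, 1.0]]
query_vec = [0.0, 0.8, 0.3]  # would come from the text side of the same model
print(rank(query_vec, image_vecs))  # [1, 2, 0]
```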

In [None]:
_ = compute_openai_large_embedding_cost(text_chunks, verbose=True)


# Check if both Azure environment variables exist
azure_endpoint = os.getenv("AZURE_API_BASE")
azure_api_key = os.getenv("AZURE_API_KEY")
if azure_endpoint and azure_api_key:
    text_embedding_model = OpenAITextEmbeddingsAzure()
    print("Using AzureOpenAI client")
else:
    text_embedding_model = OpenAITextEmbeddings()
    print("Using OpenAI client")

    
text_embeddings = text_embedding_model.get_embedding(
    [chunk.content for chunk in text_chunks]
)

print(text_embeddings.shape)
text_embeddings[0]

In [None]:
image_embeddings = []

image_embedding_model = ImageEmbeddings()
for chunk in tqdm(image_chunks):
    image_embeddings.append(image_embedding_model.get_embedding(chunk.content))


image_embeddings = np.array(image_embeddings)

In [None]:
# Also define the text embedding for the image-text embedding model
image_text_embedding_model = ImageEmbeddingsForText()

## Vector Store and Retrieval  

After embedding the chunks, they need to be stored for efficient retrieval. The choice of vector store depends on factors like accuracy, speed, and filtering options. In this exercise, we use `ChromaDB`.  

The next step is retrieving the most relevant chunks for a query. In this implementation, the retriever uses only embeddings (dense search). In some cases, however, sparse keyword-based methods such as BM25, or hybrid approaches combining sparse and dense search, can improve accuracy when retrieving text. Some retrieval strategies also make use of the metadata.
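For comparison, a sparse retriever can be written in a few lines. The sketch below implements Okapi BM25 with the common defaults `k1=1.5` and `b=0.75` and the non-negative IDF variant used by Lucene; tokenization is deliberately naive, and none of this is part of the notebook's `VectorStoreRetriever`.

```python
import math
from collections import Counter
from typing import List


def tokenize(text: str) -> List[str]:
    return text.lower().split()  # deliberately naive: no punctuation handling or stemming


class BM25:
    def __init__(self, docs: List[str], k1: float = 1.5, b: float = 0.75):
        self.k1, self.b = k1, b
        self.tfs = [Counter(tokenize(d)) for d in docs]       # term frequencies per doc
        self.lengths = [sum(tf.values()) for tf in self.tfs]  # doc lengths in tokens
        self.avgdl = sum(self.lengths) / len(self.tfs)
        n = len(self.tfs)
        df = Counter(term for tf in self.tfs for term in tf)  # docs containing each term
        self.idf = {t: math.log(1 + (n - f + 0.5) / (f + 0.5)) for t, f in df.items()}

    def score(self, query: str, i: int) -> float:
        tf, dl = self.tfs[i], self.lengths[i]
        total = 0.0
        for term in tokenize(query):
            f = tf.get(term, 0)
            if f == 0:
                continue
            # Term-frequency saturation (k1) and length normalisation (b)
            norm = f + self.k1 * (1 - self.b + self.b * dl / self.avgdl)
            total += self.idf[term] * f * (self.k1 + 1) / norm
        return total

    def top_k(self, query: str, k: int = 5) -> List[int]:
        scores = [self.score(query, i) for i in range(len(self.tfs))]
        return sorted(range(len(self.tfs)), key=lambda i: scores[i], reverse=True)[:k]


docs = [
    "shap values explain the edema model",
    "patients received radiotherapy",
    "edema grade prediction with shap",
]
bm25 = BM25(docs)
print(bm25.top_k("shap edema", k=2))  # [2, 0]
```

In a hybrid setup, this keyword ranking and the embedding ranking would then be merged, for example with reciprocal rank fusion.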

Text and image embeddings are produced by different models and stored in separate collections, so their similarity scores are not directly comparable. The retrieval strategy implemented here therefore takes the `top_k` results for each data type separately.
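The selection rule can be illustrated on a flat list of scored results. This sketch assumes higher scores mean more similar (ChromaDB itself returns distances, where lower is closer), and the tuple layout and the name `top_k_per_type` are hypothetical; the point is that scores are only sorted within each data type.

```python
from typing import Dict, List, Tuple

Result = Tuple[str, int, float]  # (data_type, chunk_id, similarity score)


def top_k_per_type(results: List[Result], k_by_type: Dict[str, int]) -> Dict[str, List[Result]]:
    """Pick the best k results for each data type; scores are only compared within a type."""
    selected = {}
    for data_type, k in k_by_type.items():
        same_type = [r for r in results if r[0] == data_type]
        same_type.sort(key=lambda r: r[2], reverse=True)  # higher score = more similar
        selected[data_type] = same_type[:k]
    return selected


results = [
    ("text", 0, 0.82), ("text", 1, 0.40), ("text", 2, 0.77),
    ("image", 10, 0.31), ("image", 11, 0.29),
]
picked = top_k_per_type(results, {"text": 2, "image": 1})
print({t: [r[1] for r in rs] for t, rs in picked.items()})  # {'text': [0, 2], 'image': [10]}
```

The image scored 0.31 is kept even though a text chunk scored 0.40 was dropped: the two numbers come from different models and are never compared with each other.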


In [None]:
vector_store_text = ChromaDBVectorStore(text_vector_store_collection)
vector_store_text.insert_chunks(text_chunks, text_embeddings)

In [None]:
vector_store_image = ChromaDBVectorStore(image_vector_store_collection)
vector_store_image.insert_chunks(image_chunks, image_embeddings)

In [None]:
retriever = VectorStoreRetriever(
    text_embedding_model,
    vector_store_text,
    image_text_embedding_model,
    vector_store_image,
)

results = retriever.retrieve(test_question, top_k_text=5, top_k_image=1)

In [None]:
for result_l in results:
    for result in result_l:
        if result["chunk"].data_type == DataType.TEXT:
            print(result)
        elif result["chunk"].data_type == DataType.IMAGE:
            print(f"Chunk ID: {result['chunk_id']} | Score: {result['score']}")
            img_data = base64.b64decode(result["chunk"].content)
            img = Image.open(io.BytesIO(img_data))
            plt.imshow(img)
            plt.axis("off")
            plt.show()

## Generator  

Once the LLM is set up, a specific prompt needs to be defined for the RAG system. This prompt must include the retrieved chunks as context. The prompt has to be adapted to each specific project.

In addition to the basic prompt, we incorporate **prompt engineering** by asking the LLM to justify its answer. The model is also instructed to indicate which chunks were most relevant in forming its response, improving **interpretability**, and to provide the answer in **JSON format** for easier data management.
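Because the model is asked to answer in JSON, its reply has to be parsed and validated before use. The sketch below is one defensive way to do this, assuming the reply may be wrapped in extra text or a markdown code fence; the function name and validation rules are illustrative, not the notebook's `Generator` logic. The key names match those requested in `default_developer_prompt`.

```python
import json
import re

REQUIRED_KEYS = {"step_by_step_thinking", "chunk_used", "answer"}


def parse_rag_output(raw: str) -> dict:
    """Extract and validate the JSON object the developer prompt asks for."""
    # Grab the span from the first "{" to the last "}", which also strips
    # any chatty preamble or surrounding markdown code fence.
    match = re.search(r"\{.*\}", raw, flags=re.DOTALL)
    if match is None:
        raise ValueError("no JSON object found in model output")
    data = json.loads(match.group(0))
    missing = REQUIRED_KEYS - set(data)
    if missing:
        raise ValueError(f"missing keys: {sorted(missing)}")
    used = data["chunk_used"]
    if not isinstance(used, list) or not all(isinstance(i, int) for i in used):
        raise ValueError("chunk_used must be a list of integer chunk ids")
    return data


fence = "`" * 3  # three backticks, built in code so they don't clash with this block's fence
raw = fence + 'json\n{"step_by_step_thinking": "Chunk 1 lists the factors.", "chunk_used": [1], "answer": "..."}\n' + fence
print(parse_rag_output(raw)["chunk_used"])  # [1]
```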

In [None]:
default_developer_prompt = """You are a helpful assistant, and your task is to answer questions using relevant chunks and images. Please first think step-by-step by mentioning which chunks you used and then answer the question. Organize your output in a json formatted as dict{"step_by_step_thinking": Str(explanation), "chunk_used": List(integers), "answer": Str{answer}}. Your responses will be read by someone without specialized knowledge, so please have a definite and concise answer."""
print(default_developer_prompt)

In [None]:
default_rag_template = """
Here are the relevant CHUNKS:
{context}

--------------------------------------------

Here is the USER QUESTION:
{query}

--------------------------------------------

Please think step-by-step and generate your output in json:
"""
print(default_rag_template)
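To see what the model actually receives, the template can be filled by hand. The sketch below numbers each chunk so the ids line up with the `chunk_used` field, inlines text chunks, and only references image chunks by id, assuming their base64 data is sent as separate image parts of the user message. `format_context` is hypothetical and `TEMPLATE` is an abridged copy of `default_rag_template`; how `Generator` formats chunks internally is not shown in this notebook.

```python
from typing import List, Tuple

# Abridged copy of default_rag_template
TEMPLATE = (
    "Here are the relevant CHUNKS:\n{context}\n\n"
    "Here is the USER QUESTION:\n{query}\n\n"
    "Please think step-by-step and generate your output in json:\n"
)


def format_context(chunks: List[Tuple[int, str, str]]) -> str:
    """chunks are (chunk_id, data_type, content) tuples. Text is inlined;
    images are only referenced by id, as their pixels travel as separate image parts."""
    lines = []
    for chunk_id, data_type, content in chunks:
        if data_type == "image":
            lines.append(f"Chunk {chunk_id}: [image attached separately]")
        else:
            lines.append(f"Chunk {chunk_id}: {content}")
    return "\n\n".join(lines)


prompt = TEMPLATE.format(
    context=format_context([(0, "image", "<base64 data>"), (1, "text", "SHAP analysis showed ...")]),
    query="Which factors were most influential?",
)
print(prompt)
```

`str.format` works here because the template contains no literal braces other than `{context}` and `{query}`; a template that did (such as an inline JSON example) would need them doubled as `{{` and `}}`.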

In [None]:
generator = Generator(llm, default_developer_prompt, default_rag_template)

In [None]:
answer, cost = generator.generate(
    history=[],
    query=test_question,
    chunks=[
        Chunk(
            chunk_id=0,
            data_type=DataType.IMAGE,
            content=image_chunks[2].content,
            metadata={},
        ),
        Chunk(
            chunk_id=1,
            data_type=DataType.TEXT,
            content=text_chunks[0].content,
            metadata={},
        ),
    ],
    verbose=True,
)

In [None]:
print(answer.content)

In [None]:
print(cost)

## RAG "Tricks"  

There are several methods to improve the retrieval quality of a RAG pipeline, such as query contextualization, query reformulation, re-ranking, and query expansion. For the sake of time, none of them is implemented here.
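Reciprocal rank fusion, mentioned in the pipeline overview, is simple enough to sketch. It merges several ranked lists (for example a dense and a BM25 ranking, or the results of several reformulated queries in query expansion) by summing `1 / (k + rank)` for each chunk, with `k = 60` as the commonly used constant. The chunk ids below are made up, and this is not implemented in the notebook's helpers.

```python
from typing import Dict, Hashable, List


def reciprocal_rank_fusion(rankings: List[List[Hashable]], k: int = 60) -> List[Hashable]:
    """Fuse several best-first rankings: score(d) = sum over lists of 1 / (k + rank(d))."""
    scores: Dict[Hashable, float] = {}
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)


dense = ["c3", "c1", "c7"]   # e.g. embedding-based ranking
sparse = ["c1", "c9", "c3"]  # e.g. BM25 ranking
print(reciprocal_rank_fusion([dense, sparse]))  # ['c1', 'c3', 'c9', 'c7']
```

Only ranks are used, never raw scores, which is why RRF can combine retrievers whose scores live on incompatible scales. A chunk missing from one list simply gets no contribution from it.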

## RAG  

Finally, the RAG pipeline is defined by integrating all the previously discussed components into a unified process.

### Text only

In [None]:
rag_without_images = DefaultRAG(
    llm=llm,
    text_embedding_model=text_embedding_model,
    text_vector_store=vector_store_text,
    generator=generator,
    params={"top_k_text": 5},
)

In [None]:
print(test_question)

In [None]:
answer, sources, cost = rag_without_images.execute(test_question, {}, verbose=True)

In [None]:
print(json.dumps(answer, indent=3))

In [None]:
# The chunks retrieved by the retriever:
print(len(sources))
print(sources[0])

In [None]:
print(cost)

### With text and images

In [None]:
rag = DefaultRAG(
    llm=llm,
    text_embedding_model=text_embedding_model,
    text_vector_store=vector_store_text,
    image_text_embedding_model=image_text_embedding_model,
    image_vector_store=vector_store_image,
    generator=generator,
    params={"top_k_text": 5, "top_k_image": 1},
)

In [None]:
print(test_question)

In [None]:
answer, sources, cost = rag.execute(test_question, {}, verbose=True)

In [None]:
print(json.dumps(answer, indent=3))

In [None]:
# The chunks retrieved by the retriever:
print(len(sources))
print(sources[0])

In [None]:
print(cost)

----------------