# Multimodal RAG with Weaviate

Search and generate responses using both text and images. We'll use pages from the HAI AI report PDF as our multimodal data.

![images/multimodal_rag.png](images/multimodal_rag.png)

In [None]:
%pip install -Uqq weaviate-client==v4.17.0-rc1

In [None]:
# Refresh credentials & load the Weaviate IP
from helpers import update_creds

AWS_ACCESS_KEY, AWS_SECRET_KEY, AWS_SESSION_TOKEN = update_creds()

%store -r WEAVIATE_IP

In [None]:
import weaviate

# Connect to the Weaviate instance at the stored IP, passing AWS credentials as headers
client = weaviate.connect_to_local(
    host=WEAVIATE_IP,
    headers = {
        "X-AWS-Access-Key": AWS_ACCESS_KEY,
        "X-AWS-Secret-Key": AWS_SECRET_KEY,
        "X-AWS-Session-Token": AWS_SESSION_TOKEN,
    }
)

client.is_ready()

Run the following if you need to delete an existing collection and start fresh:

In [None]:
client.collections.delete("Pages")

### Data import

Create a collection in which each object represents one page of the PDF document, stored as an image together with its embedding.

The vector configuration is set up with: `Configure.Vectors.multi2vec_aws`

This tells Weaviate to generate the embeddings itself through AWS (here with the `amazon.titan-embed-image-v1` model), using the `page_image` property as the input for the named vector `page`.

In [None]:
from weaviate.classes.config import Property, DataType, Configure, Tokenization

client.collections.create(
    name="Pages",
    properties=[
        Property(
            name="document_title",
            data_type=DataType.TEXT,
        ),
        Property(
            name="page_image",
            data_type=DataType.BLOB,
        ),
        Property(
            name="filename",
            data_type=DataType.TEXT,
            tokenization=Tokenization.FIELD
        ),
    ],
    vector_config=[
        Configure.Vectors.multi2vec_aws(
            name="page",
            image_fields=["page_image"],
            region="us-west-2",
            model="amazon.titan-embed-image-v1"
        )
    ]
)

Now we can load the data into the collection. First, get a reference to the collection.

In [None]:
pages = client.collections.use("Pages")

Convert the PDF pages to images, then load them into the collection. Weaviate generates the embeddings automatically during import.

In [None]:
try:
    import pymupdf
except ImportError:
    %pip install -Uqq pymupdf

In [None]:
%%bash
python pdf_to_img.py hai*.pdf
echo "Images extracted from AI Report PDF"

In [None]:
from tqdm import tqdm
from pathlib import Path
import base64
from weaviate.util import generate_uuid5


img_files = sorted(Path("data/imgs").glob("*.jpg"))

with pages.batch.fixed_size(batch_size=10) as batch:
    for filepath in tqdm(img_files[:100]):
        image = filepath.read_bytes()
        base64_image = base64.b64encode(image).decode('utf-8')
        obj = {
            "document_title": "HAI report",
            "page_image": base64_image,
            "filename": filepath.name
        }

        # Add the object to the batch; the UUID is derived from the filename
        batch.add_object(
            properties=obj,
            uuid=generate_uuid5(filepath.name)
        )
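
The import loop relies on two small ideas: the image is sent as a base64 string, and the object ID is derived from the filename, so re-running the import should address the same objects rather than create new ones. Below is a minimal sketch of both ideas using only the standard library. `encode_image` and `stable_id` are hypothetical helpers, and `uuid.uuid5` is only a stand-in for `generate_uuid5`, whose namespace may differ.

```python
import base64
import uuid


def encode_image(raw: bytes) -> str:
    """Base64-encode raw image bytes into a plain string."""
    return base64.b64encode(raw).decode("utf-8")


def stable_id(filename: str) -> str:
    """Derive a deterministic UUID from a filename (stdlib stand-in, an assumption)."""
    return str(uuid.uuid5(uuid.NAMESPACE_DNS, filename))


# Placeholder bytes; a real run would use filepath.read_bytes()
raw = b"\xff\xd8\xff\xe0 placeholder image bytes"
encoded = encode_image(raw)

# Decoding returns the original bytes, so nothing is lost in transit
round_trip = base64.b64decode(encoded)

# The same filename always maps to the same ID; different filenames do not collide here
same_id = stable_id("page_001.jpg") == stable_id("page_001.jpg")
different_id = stable_id("page_001.jpg") != stable_id("page_002.jpg")
```

Deriving the ID from the filename is what makes the import repeatable: the ID depends only on the input, not on when or how often the loop runs.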

### Queries

Search the page images using text queries.

Because the collection is configured with a vectorizer, Weaviate can embed the query text for us. We use the `near_text` method to find the most relevant pages in the collection with semantic search.

In [None]:
response = pages.query.near_text(
    query="RAG",
    limit=2,
)

for o in response.objects:
    print(f"Filename: {o.properties['filename']}")
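
Conceptually, a `near_text` query embeds the query text and returns the objects whose vectors are closest to it. The sketch below illustrates that ranking step with toy three-dimensional vectors and cosine similarity. The vectors, the filenames, and the `top_k` helper are all made up for illustration, and a real vector database uses an approximate index rather than this exhaustive comparison.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k(query_vec, page_vecs, k):
    """Return the k page names most similar to the query (brute force)."""
    ranked = sorted(
        page_vecs.items(),
        key=lambda item: cosine_similarity(query_vec, item[1]),
        reverse=True,
    )
    return [name for name, _ in ranked[:k]]


# Toy "page embeddings"; real embeddings have hundreds of dimensions or more
toy_pages = {
    "page_a.jpg": [0.9, 0.1, 0.0],
    "page_b.jpg": [0.1, 0.9, 0.0],
    "page_c.jpg": [0.7, 0.3, 0.1],
}
toy_query = [1.0, 0.0, 0.0]

nearest = top_k(toy_query, toy_pages, k=2)
```

Here `nearest` holds `page_a.jpg` followed by `page_c.jpg`, since their vectors point most nearly in the same direction as the query.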

Display the retrieved images to see what was found.

In [None]:
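def display_imgs(images_to_display):
    """Show the given image files side by side."""
    import matplotlib.pyplot as plt
    from PIL import Image

    # One column per image; squeeze=False keeps `axes` 2-D even for a single image
    fig, axes = plt.subplots(
        1, len(images_to_display), figsize=(30, 40), squeeze=False
    )

    for ax, img_path in zip(axes[0], images_to_display):
        ax.imshow(Image.open(img_path))
        ax.axis('off')

    plt.tight_layout()
    plt.show()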


images = [
    f"data/imgs/{o.properties['filename']}" for o in response.objects
]
display_imgs(images)

In [None]:
response = pages.query.near_text(
    query="self-driving cars",
    limit=2,
)

for o in response.objects:
    print(f"Filename: {o.properties['filename']}")

In [None]:
images = [
    f"data/imgs/{o.properties['filename']}" for o in response.objects
]
display_imgs(images)

### Retrieval augmented generation (RAG)

Generate text responses based on image content.

Combine image retrieval with AI text generation for detailed analysis.

In [None]:
from weaviate.classes.generate import GenerativeConfig, GenerativeParameters

prompt = GenerativeParameters.grouped_task(
    prompt="What does this say about self-driving cars?",
    image_properties=["page_image"]  # Property containing images in Weaviate
)

gen_config_aws = GenerativeConfig.aws(
    region="us-west-2",
    service="bedrock",
    model="us.amazon.nova-pro-v1:0"
)

# We use `pages.generate` here to generate a response based on the retrieved pages.
response = pages.generate.near_text(
    query="self-driving cars",
    limit=2,
    # These parameters are used to define the RAG task & model
    grouped_task=prompt,
    generative_provider=gen_config_aws
)

And the results are:

In [None]:
print(response.generative.text)
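
A grouped task sends the prompt together with all retrieved objects in a single generation request, rather than one request per object. The sketch below illustrates that flow with a stand-in model. `grouped_task`, `fake_model`, and the record layout are hypothetical and do not reflect how Weaviate builds the request internally.

```python
from typing import Callable

calls = []  # records every call made to the stand-in model


def fake_model(prompt: str, images: list) -> str:
    """Stand-in for a multimodal LLM; records the call instead of generating text."""
    calls.append((prompt, len(images)))
    return f"answer based on {len(images)} page image(s)"


def grouped_task(prompt: str, pages: list, model: Callable[[str, list], str]) -> str:
    """Send one request containing the prompt plus every retrieved page image."""
    images = [page["page_image"] for page in pages]
    return model(prompt, images)


# Pretend these two pages came back from the retrieval step
retrieved = [
    {"filename": "page_010.jpg", "page_image": "<base64 A>"},
    {"filename": "page_042.jpg", "page_image": "<base64 B>"},
]

answer = grouped_task(
    "What does this say about self-driving cars?", retrieved, fake_model
)
```

The model is called once and sees both pages together, which is why a grouped task suits questions that span several retrieved pages.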

In [None]:
from weaviate.classes.generate import GenerativeConfig, GenerativeParameters

prompt = GenerativeParameters.grouped_task(
    prompt="What does each of these pages describe?",
    image_properties=["page_image"]  # Property containing images in Weaviate
)

# We use `pages.generate` here to generate a response based on the retrieved pages.
response = pages.generate.near_text(
    query="advances in RAG",
    limit=3,
    # These parameters are used to define the RAG task & model
    grouped_task=prompt,
    generative_provider=gen_config_aws
)

print(response.generative.text)

### In-depth research & analysis

Use multimodal RAG for detailed document analysis.

In [None]:
from weaviate.classes.generate import GenerativeConfig, GenerativeParameters

prompt = GenerativeParameters.grouped_task(
    prompt="What do these pages highlight about the recent advances in RAG?",
    image_properties=["page_image"]  # Property containing images in Weaviate
)

# We use `pages.generate` here to generate a response based on the retrieved pages.
response = pages.generate.near_text(
    query="advances in RAG",
    limit=3,
    # These parameters are used to define the RAG task & model
    grouped_task=prompt,
    generative_provider=gen_config_aws
)

In [None]:
print(response.generative.text)
for o in response.objects:
    print(o.properties["filename"])

### Close the client

Always close your connection when finished.

In [None]:
client.close()
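
If a cell raises before `client.close()` runs, the connection can stay open. One way to guard against that is to wrap the work in a context manager so that `close()` is called even on error. The sketch below shows the pattern with `contextlib.closing` and a hypothetical `FakeClient` stand-in, so it runs without a database; the same idea applies to any object that exposes `close()`.

```python
from contextlib import closing


class FakeClient:
    """Hypothetical stand-in for a database client that exposes close()."""

    def __init__(self):
        self.closed = False

    def is_ready(self):
        return not self.closed

    def close(self):
        self.closed = True


# Normal case: the client is usable inside the block and closed afterwards
fake = FakeClient()
with closing(fake) as c:
    ready_inside = c.is_ready()
closed_after = fake.closed

# Error case: close() still runs even though the block raises
failing = FakeClient()
try:
    with closing(failing):
        raise RuntimeError("query failed")
except RuntimeError:
    pass
closed_after_error = failing.closed
```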