<a href="https://colab.research.google.com/github/tuhinmallick/AI-for-Fashion/blob/main/Multimodal_RAG_on_Your_Computer_with_ColPali_and_Qwen2_VL_for_PDF_Documents.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
# Step 1: Simulate the graphqlCache data with two products (replace this with your actual data)
graphqlCache = {
    'product_1': {
        'data': {
            'product': {
                'id': 'ern:product::STH51L1F3-Q11',
                'sku': 'STH51L1F3-Q11',
                'name': 'SET OF 2 - Haar-Styling-Accessoires - black',
                'navigationTargetGroup': 'WOMEN',
                'silhouette': 'JEWELLERY',
                'supplierName': 'SET OF 2',
                'shortDescription': None,
                'brandRestriction': {'themeOverride': None},
                'brand': {'name': 'Stradivarius'},
                'displayPrice': {
                    'trackingCurrentAmount': 12.99,
                    'trackingDiscountAmount': None,
                    'original': {
                        'formatted': '12,99 €',
                        'amount': 1299,
                        'currency': 'EUR'
                    },
                    'displayMode': 'BLACK_PRICE',
                    'priceDisplayTracking': {'id': 'BLACK_PRICE'}
                },
                'simples': [{'size': 'One Size', 'sku': 'STH51L1F3-Q110ONE000', 'available': True}],
                'defaultMediaInfo': {'alternativeText': 'Stradivarius SET OF 2 - Haar-Styling-Accessoires - black'},
                'smallDefaultMedia': {'uri': 'https://img01.ztat.net/article/spp-media-p1/287a02f4c9c145ffa4a14f22fddaa88b.jpg?imwidth=300'},
                'mediumDefaultMedia': {'uri': 'https://img01.ztat.net/article/spp-media-p1/287a02f4c9c145ffa4a14f22fddaa88b.jpg?imwidth=400'},
                'largeDefaultMedia': {'uri': 'https://img01.ztat.net/article/spp-media-p1/287a02f4c9c145ffa4a14f22fddaa88b.jpg?imwidth=500'},
                'extraLargeDefaultMedia': {'uri': 'https://img01.ztat.net/article/spp-media-p1/287a02f4c9c145ffa4a14f22fddaa88b.jpg?imwidth=780'},
                'uri': 'https://www.zalando.de/stradivarius-haar-styling-accessoires-black-sth51l1f3-q11.html'
            }
        }
    },
    'product_2': {
        'data': {
            'product': {
                'id': 'ern:product::G2U31G02H-S11',
                'sku': 'G2U31G02H-S11',
                'name': 'GARNIER 12 DAYS ADVENT CALENDAR 2023 - Adventskalender - -',
                'navigationTargetGroup': 'WOMEN',
                'silhouette': 'SETS_AND_PALETTES',
                'supplierName': 'GARNIER 12 DAYS ADVENT CALENDAR 2023',
                'brand': {'name': 'Garnier'},
                'displayPrice': {
                    'trackingCurrentAmount': 13.00,
                    'trackingDiscountAmount': 6.95,
                    'original': {
                        'formatted': '19,95 €',
                        'amount': 1995,
                        'currency': 'EUR',
                        'discountLabel': '-35%'
                    },
                    'displayMode': 'TWO_PRICES',
                    'priceDisplayTracking': {'id': 'TWO_PRICES'}
                },
                'simples': [{'size': 'One Size', 'sku': 'G2U31G02H-S110ONE000', 'available': True}],
                'defaultMediaInfo': {'alternativeText': 'Garnier GARNIER 12 DAYS ADVENT CALENDAR 2023 - Adventskalender - -'},
                'smallDefaultMedia': {'uri': 'https://img01.ztat.net/article/spp-media-p1/8302af8513d340c59817e81eef6114a4/e183d4365dc9416594820a80073a62fc.jpg?imwidth=300'},
                'mediumDefaultMedia': {'uri': 'https://img01.ztat.net/article/spp-media-p1/8302af8513d340c59817e81eef6114a4/e183d4365dc9416594820a80073a62fc.jpg?imwidth=400'},
                'largeDefaultMedia': {'uri': 'https://img01.ztat.net/article/spp-media-p1/8302af8513d340c59817e81eef6114a4/e183d4365dc9416594820a80073a62fc.jpg?imwidth=500'},
                'extraLargeDefaultMedia': {'uri': 'https://img01.ztat.net/article/spp-media-p1/8302af8513d340c59817e81eef6114a4/e183d4365dc9416594820a80073a62fc.jpg?imwidth=780'},
                'uri': 'https://www.zalando.de/garnier-adventskalender-garnier-tuchmasken-offline-neu-adventskalender-g2u31g02h-s11.html'
            }
        }
    }
}

# Initialize an array to store the parsed products
products = []

# Function to calculate the price group based on a given price and step size
def calculate_price_group(price, step_size=20):
    if price < 0:
        return 'Unknown'
    lower_bound = (price // step_size) * step_size
    upper_bound = lower_bound + step_size
    return f'{lower_bound}-{upper_bound}'

# Parse all products in graphqlCache
for key, cache_entry in graphqlCache.items():
    product_data = cache_entry.get('data', {}).get('product', {})
    if not product_data:
        continue  # Skip if product data is missing

    # Extract and assign variables with required names
    product_id = product_data.get('id')
    product_handle = product_data.get('uri')
    product_title = product_data.get('name')
    product_vendor = product_data.get('supplierName')
    product_type = product_data.get('silhouette')
    product_description = product_data.get('shortDescription', 'N/A')
    product_brand = product_data.get('brand', {}).get('name')

    # Price extraction
    display_price = product_data.get('displayPrice', {})
    product_original_price = display_price.get('original', {}).get('amount', 0) / 100  # convert cents to euros
    product_discount_price = display_price.get('trackingCurrentAmount', 0)
    product_discount_amount = display_price.get('trackingDiscountAmount', 0)
    product_currency = display_price.get('original', {}).get('currency', 'EUR')
    product_price_group = calculate_price_group(product_original_price)

    # Product deal: If there is a discount
    # Product deal: If there is a discount
    product_deal = None
    if product_discount_amount is not None and product_discount_amount > 0: # Check if product_discount_amount is not None before comparing
        product_deal = {
            'discount_amount': product_discount_amount,
            'discount_percentage': f"{(product_discount_amount / product_original_price) * 100:.2f}%"
        }

    # Sizes and Variants
    variants = product_data.get('simples', [])
    product_sizes = []
    for variant in variants:
        size = variant.get('size', 'N/A')
        sku = variant.get('sku')
        available = variant.get('available', True)
        product_sizes.append({
            'size': size,
            'sku': sku,
            'price': product_discount_price,
            'compare_at_price': product_original_price if product_original_price > product_discount_price else None,
            'available': available
        })

    # Image URLs
    product_page_image_url = product_data.get('defaultMediaInfo', {}).get('alternativeText', 'No image available')
    product_packshot_images = {
        'product_packshot_image_small': product_data.get('smallDefaultMedia', {}).get('uri', None),
        'product_packshot_image_medium': product_data.get('mediumDefaultMedia', {}).get('uri', None),
        'product_packshot_image_large': product_data.get('largeDefaultMedia', {}).get('uri', None),
        'product_packshot_image_xlarge': product_data.get('extraLargeDefaultMedia', {}).get('uri', None)
    }

    # Construct product object
    product_obj = {
        'product_id': product_id,
        'product_handle': product_handle,
        'product_title': product_title,
        'product_vendor': product_vendor,
        'product_type': product_type,
        'product_tags': [],  # Assuming no tags in the data
        'product_description': product_description,
        'product_brand': product_brand,
        'product_original_price': product_original_price,
        'product_discount_price': product_discount_price,
        'product_discount_amount': product_discount_amount,
        'product_currency': product_currency,
        'product_price_group': product_price_group,
        'product_deal': product_deal,
        'product_sizes': product_sizes,
        'product_page_image_url': product_page_image_url,
        'product_packshot_images': product_packshot_images,
        'product_url': product_handle  # Same as product_handle
    }

    # Append to products list
    products.append(product_obj)

# Output the parsed products
for product in products:
    print(product)


{'product_id': 'ern:product::STH51L1F3-Q11', 'product_handle': 'https://www.zalando.de/stradivarius-haar-styling-accessoires-black-sth51l1f3-q11.html', 'product_title': 'SET OF 2 - Haar-Styling-Accessoires - black', 'product_vendor': 'SET OF 2', 'product_type': 'JEWELLERY', 'product_tags': [], 'product_description': None, 'product_brand': 'Stradivarius', 'product_original_price': 12.99, 'product_discount_price': 12.99, 'product_discount_amount': None, 'product_currency': 'EUR', 'product_price_group': '0.0-20.0', 'product_deal': None, 'product_sizes': [{'size': 'One Size', 'sku': 'STH51L1F3-Q110ONE000', 'price': 12.99, 'compare_at_price': None, 'available': True}], 'product_page_image_url': 'Stradivarius SET OF 2 - Haar-Styling-Accessoires - black', 'product_packshot_images': {'product_packshot_image_small': 'https://img01.ztat.net/article/spp-media-p1/287a02f4c9c145ffa4a14f22fddaa88b.jpg?imwidth=300', 'product_packshot_image_medium': 'https://img01.ztat.net/article/spp-media-p1/287a02f

*More details in this article: [Multimodal RAG with ColPali and Qwen2-VL on Your Computer](https://newsletter.kaitchup.com/p/multimodal-rag-with-colpali-and-qwen2)*

This notebook shows how to run a multimodal RAG system using ColPali, wrapped with Byaldi, and Qwen2-VL.

It uses a PDF documet (a scientific paper) as a source of external knowledge.

The notebook consumes around 27 GB of GPU VRAM. If you want to run it on a consumer GPU with less than 24 GB of VRAM, you can use the quantized version of Qwen2-VL-7B-Instruct (GPTQ or AWQ) or use Qwen2-VL-2B-Instruct. You will find them in the Qwen2-VL collection.

I really recommend a GPU supporting FlashAttention, i.e., an Ampere or a more recent GPU. Otherwise, it will consume much more memory.

I used code published from two GitHub repositories:

* [Byaldi](https://github.com/AnswerDotAI/byaldi)
* [Smol-vision](https://github.com/merveenoyan/smol-vision/blob/main/ColPali_%2B_Qwen2_VL.ipynb)





#Setting Up your Environment for ColPali and Qwen2-VL

* I needed to install Transformers from source for running Qwen2-VL.

* FlashAttention (flash_attn) isn’t mandatory but without it, you won’t be able to run Qwen2-VL 2B on a consumer GPU (24 GB).

* Byaldi makes very easy the use of ColPali for RAG. It’s a wrapper of the ColPali repository.

* qwen_vl_utils is made by the Qwen team to facilitate the processing of the input for Qwen VL.

* pdf2image is used to convert the PDF pages into images. This is necessary since Qwen2-VL can’t encode PDF files.

* poppler-utils is required by Byaldi to index the PDF pages.

In [None]:
!pip install --upgrade git+https://github.com/huggingface/transformers.git byaldi accelerate flash-attn qwen_vl_utils pdf2image
!sudo apt-get install -y poppler-utils

Collecting git+https://github.com/huggingface/transformers.git
  Cloning https://github.com/huggingface/transformers.git to /tmp/pip-req-build-1caauxrq
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-req-build-1caauxrq
  Resolved https://github.com/huggingface/transformers.git to commit 8f8af0fb38baa851f3fd69f564fbf91b5af78332
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting byaldi
  Downloading Byaldi-0.0.2.post2-py3-none-any.whl.metadata (20 kB)
Collecting accelerate
  Downloading accelerate-0.34.2-py3-none-any.whl.metadata (19 kB)
Collecting flash-attn
  Downloading flash_attn-2.6.3.tar.gz (2.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.6/2.6 MB[0m [31m74.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting qwen_

I’ve downloaded a paper from arXiv, named it "1.pdf" and put it in the directory "documents/". You can put several PDF files and images in this directory. They will all be indexed by Byaldi.

In [None]:
!mkdir documents
%cd documents
!wget https://arxiv.org/pdf/2409.06697 -O 1.pdf
%cd ..

/content/documents
--2024-09-12 16:14:20--  https://arxiv.org/pdf/2409.06697
Resolving arxiv.org (arxiv.org)... 151.101.67.42, 151.101.3.42, 151.101.195.42, ...
Connecting to arxiv.org (arxiv.org)|151.101.67.42|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4414654 (4.2M) [application/pdf]
Saving to: ‘1.pdf’


2024-09-12 16:14:21 (14.4 MB/s) - ‘1.pdf’ saved [4414654/4414654]

/content


#Load the Models for Multimodal RAG

We load ColPali with RAGMultiModalModel from byaldi.

In [None]:
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
from pdf2image import convert_from_path

RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")
model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch.bfloat16,  attn_implementation="flash_attention_2", device_map="cuda")


Verbosity is set to 1 (active). Pass verbose=0 to make quieter.


adapter_config.json:   0%|          | 0.00/752 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/66.3k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/862M [00:00<?, ?B/s]

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

adapter_model.safetensors:   0%|          | 0.00/78.6M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/700 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/243k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.26M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.8M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/24.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/733 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.20k [00:00<?, ?B/s]

Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'mrope_section'}


model.safetensors.index.json:   0%|          | 0.00/56.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/5 [00:00<?, ?it/s]

model-00001-of-00005.safetensors:   0%|          | 0.00/3.90G [00:00<?, ?B/s]

model-00002-of-00005.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00003-of-00005.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00004-of-00005.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00005-of-00005.safetensors:   0%|          | 0.00/1.09G [00:00<?, ?B/s]

You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the `config` argument. All other arguments will be removed in v4.46


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/244 [00:00<?, ?B/s]

#Indexing PDF Files with Byaldi

Remove "1.pdf" in the following cell if you want to index the entire "documents" directory instead.
If it works, it should output this:

In [None]:
# Optionally, you can specify an `index_root`, which is where it'll look for the index. It defaults to ".byaldi/".
RAG.index(
    input_path="documents/1.pdf", # The path to your documents
    index_name="index", # The name you want to give to your index. It'll be saved at `index_root/index_name/`.
    store_collection_with_index=False,
    overwrite=True # Whether to overwrite an index if it already exists. If False, it'll return None and do nothing if `index_root/index_name` exists.
)

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


Added page 1 of document 0 to index.
Added page 2 of document 0 to index.
Added page 3 of document 0 to index.
Added page 4 of document 0 to index.
Added page 5 of document 0 to index.
Added page 6 of document 0 to index.
Added page 7 of document 0 to index.
Added page 8 of document 0 to index.
Added page 9 of document 0 to index.
Added page 10 of document 0 to index.
Index exported to .byaldi/index
Index exported to .byaldi/index


{0: 'documents/1.pdf'}

ColPali will encode this query text and use the embedding to retrieve the relevant pages from the indexed document.

In [None]:
text_query = "What is the type of the star hosting the Kepler-51 planetary system?"
results = RAG.search(text_query, k=3)
print(results)

[{'doc_id': 0, 'page_num': 1, 'score': 23.875, 'metadata': {}, 'base64': None}, {'doc_id': 0, 'page_num': 8, 'score': 23.625, 'metadata': {}, 'base64': None}, {'doc_id': 0, 'page_num': 3, 'score': 22.625, 'metadata': {}, 'base64': None}]


#Prompting Qwen2-VL in a Multimodal RAG System

Next, we convert our PDF document into images. Then, using the results returned by Byaldi for our query, we can get the corresponding image:

In [None]:
images = convert_from_path("documents/1.pdf")
image_index = results[0]["page_num"] - 1

Then, the next step is a standard inference pipeline for Qwen2-VL:

In [None]:
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": images[image_index],
            },
            {"type": "text", "text": text_query},
        ],
    }
]



text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("cuda")


generated_ids = model.generate(**inputs, max_new_tokens=50)
generated_ids_trimmed = [
    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)




preprocessor_config.json:   0%|          | 0.00/347 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/4.19k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/2.78M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/7.03M [00:00<?, ?B/s]

chat_template.json:   0%|          | 0.00/1.05k [00:00<?, ?B/s]

['The host star is a G-type star.']
