<a href="https://colab.research.google.com/github/rsrini7/Colabs/blob/main/Multimodal_RAG_ColPali_%2B_Byaldi_%2B_Vision_AI_Models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install Unsloth, Byaldi, and poppler-utils
# Unsloth installation specific to Colab
!pip install "unsloth[colab-new]@git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps xformers trl peft accelerate bitsandbytes gradio
!pip install -q byaldi
!sudo apt-get install -y poppler-utils # For PDF processing if you extend to PDFs (Byaldi needs it)


Collecting unsloth@ git+https://github.com/unslothai/unsloth.git (from unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git)
  Cloning https://github.com/unslothai/unsloth.git to /tmp/pip-install-t4n19hlc/unsloth_4a7217980bd84f7ba8e78371ec27f8f3
  Running command git clone --filter=blob:none --quiet https://github.com/unslothai/unsloth.git /tmp/pip-install-t4n19hlc/unsloth_4a7217980bd84f7ba8e78371ec27f8f3
  Resolved https://github.com/unslothai/unsloth.git to commit 937f684d4377c465452a7723c8bb97f1ecd2a3d5
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting unsloth_zoo>=2025.5.7 (from unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git)
  Downloading unsloth_zoo-2025.5.7-py3-none-any.whl.metadata (8.0 kB)
Collecting tyro (from unsloth@ git+https://github.com/unslothai/unsloth.git

Collecting xformers
  Downloading xformers-0.0.30-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (1.0 kB)
Collecting gradio
  Downloading gradio-5.29.1-py3-none-any.whl.metadata (16 kB)
Downloading xformers-0.0.30-cp311-cp311-manylinux_2_28_x86_64.whl (31.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.5/31.5 MB[0m [31m16.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading gradio-5.29.1-py3-none-any.whl (54.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.1/54.1 MB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: xformers, gradio
Successfully installed gradio-5.29.1 xformers-0.0.30
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m517.9/517.9 kB[0m [31m18.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m422.8/422.8 kB[0m [31m29.8 MB/s

In [1]:
# Combined cells 2-10: imports, model loading, image download, RAG indexing/search, and LLM Q&A.

# --- Cell 2: Core Imports ---
import os
import torch
from unsloth import FastVisionModel
from byaldi import RAGMultiModalModel
import requests
import io
import base64
from PIL import Image as PIL_Image
import tqdm
import time
import numpy as np
from transformers import TextStreamer
from IPython.display import display
print("Cell 2: Core Imports - Complete")

# --- Cell 3: Load Unsloth FastVisionModel (for Q&A) - MODIFIED ---
# Loads the Qwen2.5-VL checkpoint with 4-bit loading disabled and switches it
# to inference mode. A failure is reported and then re-raised, since the Q&A
# function defined later relies on both `model` and `tokenizer`.
print("\nCell 3: Loading Unsloth FastVisionModel...")
try:
    model, tokenizer = FastVisionModel.from_pretrained("unsloth/Qwen2.5-VL-3B-Instruct", load_in_4bit=False)
    FastVisionModel.for_inference(model)
except Exception as load_error:
    print(f"ERROR in Cell 3 (Load Unsloth Model): {load_error}")
    raise
else:
    print("Unsloth Qwen2.5-VL model loaded and prepared for inference.")

# --- Cell 4: Load Byaldi RAG Model (for RAG Indexing Example) ---
# `RAG_byaldi` is the retriever used by the indexing and search steps below.
# It stays None unless LOAD_BYALDI_RAG is switched on; the later RAG steps
# check for that and skip themselves.
print("\nCell 4: Loading Byaldi RAG Model...")
# Off by default to save VRAM for the main Qwen model; set to True to enable RAG.
LOAD_BYALDI_RAG = False
RAG_byaldi = None
if LOAD_BYALDI_RAG:
    try:
        RAG_byaldi = RAGMultiModalModel.from_pretrained("vidore/colqwen2-v1.0", verbose=1)
        print("Byaldi RAG model loaded successfully.")
    except Exception as e:
        # Best-effort: the retriever is optional, so report the failure and
        # continue with RAG disabled instead of aborting the notebook.
        RAG_byaldi = None
        print(f"Error loading Byaldi RAG model: {e}")
        print("Proceeding without Byaldi RAG model for indexing. The direct Q&A will still work.")
else:
    print("Skipping Byaldi RAG model loading to save VRAM for Qwen2.5-VL model.")
print("Cell 4: Load Byaldi RAG Model - Complete / Skipped")

# --- Cell 5: Download Example Images (for RAG Indexing Demonstration) ---
# Fetches each example image into `img_folder` unless a file of that name is
# already present, so re-running the cell does not download again.
print("\nCell 5: Downloading Example Images...")
# Maps local file name -> source URL.
# NOTE(review): most URLs request `f_webp`, so the saved bytes may be WebP
# despite the .png file names -- confirm if the extension matters downstream.
images_data = {
    "tesla.png": "https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fbef936e6-3efa-43b3-88d7-7ec620cdb33b_2744x1539.png",
    "netflix.png": "https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F23bd84c9-5b62-4526-b467-3088e27e4193_2744x1539.png",
    "nike.png": "https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fa5cd33ba-ae1a-42a8-a254-d85e690d9870_2741x1541.png",
    "google.png": "https://substackcdn.com/image/fetch/f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F395dd3b9-b38e-4d1f-91bc-d37b642ee920_2741x1541.png",
    "accenture.png": "https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F08b2227c-7dc8-49f7-b3c5-13cab5443ba6_2741x1541.png",
    "tecent.png": "https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F0ec8448c-c4d1-4aab-a8e9-2ddebe0c95fd_2741x1541.png"
}
img_folder = "downloaded_images"
os.makedirs(img_folder, exist_ok=True)
# Paths of the images that are actually available on disk after this cell.
img_paths = []
for name, url in tqdm.tqdm(images_data.items()):
    img_path = os.path.join(img_folder, name)
    if not os.path.exists(img_path):
        try:
            response = requests.get(url, timeout=30)
            response.raise_for_status()
            with open(img_path, "wb") as fOut:
                fOut.write(response.content)
        except requests.exceptions.RequestException as e:
            print(f"Failed to download {name}: {e}")
            # Do not record a path for a file that was never written.
            continue
    img_paths.append(img_path)
print("Cell 5: Image Download - Complete")

# --- Cell 6: Perform RAG Indexing (Using Byaldi) ---
# Builds the "image_attention_index" over the downloaded image folder when a
# Byaldi model is available; otherwise the step is skipped. Indexing errors
# are printed rather than raised.
print("\nCell 6: Performing RAG Indexing (if Byaldi model loaded)...")
if not RAG_byaldi:
    print("Byaldi RAG model not loaded/skipped, skipping indexing.")
else:
    try:
        RAG_byaldi.index(
            input_path=f"./{img_folder}/",
            index_name="image_attention_index",
            store_collection_with_index=True,
            overwrite=True,
        )
        print("Image indexing complete.")
    except Exception as index_error:
        print(f"Error during RAG indexing: {index_error}")
print("Cell 6: RAG Indexing - Complete")

# --- Cell 7: Define `search_rag` function ---
def search_rag(question_text):
    """Retrieve and display the top image match for a text query.

    Runs a top-1 search on the module-level ``RAG_byaldi`` model. The image is
    taken from the result's ``base64`` payload when present; otherwise a file
    named after the result's ``doc_id`` inside ``img_folder`` is tried.

    Args:
        question_text: Query passed to ``RAG_byaldi.search``.

    Returns:
        The retrieved PIL image (also shown via ``display``), or None when the
        model is not loaded, the search returns nothing, or no image could be
        located for the result.
    """
    if not RAG_byaldi:
        print("Byaldi RAG model not loaded/skipped. Cannot perform RAG search.")
        return None

    hits = RAG_byaldi.search(question_text, k=1)
    if not hits:
        print("No results found from RAG search.")
        return None

    top_hit = hits[0]

    # Preferred source: the base64-encoded image stored with the result.
    encoded_image = getattr(top_hit, 'base64', None)
    if encoded_image:
        image = PIL_Image.open(io.BytesIO(base64.b64decode(encoded_image)))
        display(image)
        return image

    # Fallback: treat the doc_id as a file name under the image folder.
    # NOTE(review): if doc_id is a numeric index rather than a file name this
    # path will never exist -- confirm against the Byaldi result schema.
    fallback_path = os.path.join(img_folder, str(top_hit.doc_id))
    if not os.path.exists(fallback_path):
        print(f"Could not display image for doc_id: {top_hit.doc_id}.")
        return None

    image = PIL_Image.open(fallback_path)
    display(image)
    return image

# --- Cell 8: Example of using `search_rag` function ---
# Runs one sample query through search_rag, but only when a Byaldi model
# has been loaded.
print("\nCell 8: Example RAG Search...")
if not RAG_byaldi:
    print("Skipping RAG search example as Byaldi model was not loaded/skipped.")
else:
    retrieved_image_from_rag = search_rag("Show me the Nike income statement")
    if retrieved_image_from_rag:
        print("Image retrieved and displayed via RAG search.")
print("Cell 8: RAG Search Example - Complete")

# --- Cell 9: Define `answer_with_llm` function - MODIFIED ---
def answer_with_llm(question_text, pil_image_input):
    """Answer a question about an image with the loaded vision model.

    Builds a single-turn chat prompt (one image slot plus the instruction
    text), encodes it together with the image through the module-level
    ``tokenizer``, generates up to 512 new tokens with ``model``, and decodes
    only the tokens produced after the prompt.

    Args:
        question_text: The question to embed in the prompt.
        pil_image_input: The image to ask about, or None.

    Returns:
        The stripped answer text; a fixed reminder string when no image is
        given; or an error message string if generation raises.
    """
    if pil_image_input is None:
        return "Please upload an image first."

    prompt_template = f"""Answer the question based on the following image.
Don't use markdown.
Please provide enough context for your answer.
Question: {question_text}"""

    image_part = {"type": "image"}
    text_part = {"type": "text", "text": prompt_template}
    messages = [{"role": "user", "content": [image_part, text_part]}]

    chat_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    encoded = tokenizer(
        images=pil_image_input,
        text=chat_text,
        add_special_tokens=False,
        return_tensors="pt",
    )
    inputs = encoded.to(model.device)

    try:
        generated_ids = model.generate(**inputs, max_new_tokens=512, use_cache=True)
    except Exception as e_generate:
        print(f"Error during model.generate(): {e_generate}")
        return f"Error during generation: {e_generate}"

    # Slice off the prompt tokens so only the generated continuation is decoded.
    prompt_length = inputs["input_ids"].shape[1]
    continuation_ids = generated_ids[:, prompt_length:]
    decoded = tokenizer.batch_decode(continuation_ids, skip_special_tokens=True)
    return decoded[0].strip()

# --- Cell 10: Example of using `answer_with_llm` function ---
# Asks one sample question about the downloaded Nike image. Any exception is
# printed instead of raised so the closing status lines are still shown.
print("\nCell 10: Example LLM Q&A...")
try:
    nike_image_path = os.path.join(img_folder, "nike.png")
    if not os.path.exists(nike_image_path):
        print(f"Nike image not found at {nike_image_path}. Please ensure Cell 5 (Download Images) ran correctly.")
    else:
        img_nike_pil = PIL_Image.open(nike_image_path)
        question = "What is the net profit for Nike in Q3 FY25 as shown in the image?"
        print(f"Question: {question}")
        llm_response = answer_with_llm(question, img_nike_pil)
        print(f"LLM Answer: {llm_response}")
except Exception as qa_error:
    print(f"Error testing answer_with_llm: {qa_error}")
print("Cell 10: LLM Q&A Example - Complete")

print("\n--- All Combined Cells (2-10) Executed ---")

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


    PyTorch 2.7.0+cu126 with CUDA 1206 (you have 2.6.0+cu124)
    Python  3.11.12 (you have 3.11.12)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


🦥 Unsloth Zoo will now patch everything to make training faster!
Cell 2: Core Imports - Complete

Cell 3: Loading Unsloth FastVisionModel...
==((====))==  Unsloth 2025.5.5: Fast Qwen2_5_Vl patching. Transformers: 4.51.3.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = None. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Unsloth Qwen2.5-VL model loaded and prepared for inference.

Cell 4: Loading Byaldi RAG Model...
Skipping Byaldi RAG model loading to save VRAM for Qwen2.5-VL model.
Cell 4: Load Byaldi RAG Model - Complete / Skipped

Cell 5: Downloading Example Images...


100%|██████████| 6/6 [00:00<00:00, 36954.22it/s]

Cell 5: Image Download - Complete

Cell 6: Performing RAG Indexing (if Byaldi model loaded)...
Byaldi RAG model not loaded/skipped, skipping indexing.
Cell 6: RAG Indexing - Complete

Cell 8: Example RAG Search...
Skipping RAG search example as Byaldi model was not loaded/skipped.
Cell 8: RAG Search Example - Complete

Cell 10: Example LLM Q&A...
Question: What is the net profit for Nike in Q3 FY25 as shown in the image?





LLM Answer: The net profit for Nike in Q3 FY25, as shown in the image, is $0.8B.
Cell 10: LLM Q&A Example - Complete

--- All Combined Cells (2-10) Executed ---


In [3]:
# Install Gradio (with its dependencies) for the demo UI built in the next cell.
!pip install gradio

Collecting aiofiles<25.0,>=22.0 (from gradio)
  Downloading aiofiles-24.1.0-py3-none-any.whl.metadata (10 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.12-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.5.0-py3-none-any.whl.metadata (3.0 kB)
Collecting gradio-client==1.10.1 (from gradio)
  Downloading gradio_client-1.10.1-py3-none-any.whl.metadata (7.1 kB)
Collecting groovy~=0.1 (from gradio)
  Downloading groovy-0.1.2-py3-none-any.whl.metadata (6.1 kB)
Collecting pydub (from gradio)
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting python-multipart>=0.0.18 (from gradio)
  Downloading python_multipart-0.0.20-py3-none-any.whl.metadata (1.8 kB)
Collecting ruff>=0.9.3 (from gradio)
  Downloading ruff-0.11.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (25 kB)
Collecting safehttpx<0.2.0,>=0.1.6 (from gradio)
  Downloading safehttpx-0.1.6-py3-none-any.whl.metadata (4.2 kB)

In [4]:
import gradio as gr

def gradio_interface(image_upload, question_text_input):
    """Gradio callback: validate the inputs, then query the vision model.

    Args:
        image_upload: The uploaded image as a NumPy array (the Image input is
            configured with type="numpy"), or None when nothing was uploaded.
        question_text_input: The question typed by the user; may be None,
            empty, or whitespace only.

    Returns:
        The answer string from ``answer_with_llm``, or a short message asking
        for whichever input is missing.
    """
    if image_upload is None:
        return "Please upload an image."

    # Treat None and whitespace-only text the same as an empty question so a
    # blank prompt never reaches the model.
    question = (question_text_input or "").strip()
    if not question:
        return "Please enter a question."

    # Gradio provides the image as a NumPy array here; convert to PIL.
    pil_image = PIL_Image.fromarray(image_upload)

    return answer_with_llm(question, pil_image)

# Create the Gradio interface: an image upload plus a question box in,
# a single answer text box out.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        # type="numpy" hands the callback a NumPy array, which it converts to PIL.
        gr.Image(type="numpy", label="Upload Image"),
        gr.Textbox(lines=2, placeholder="Enter your question here...", label="Question"),
    ],
    outputs=gr.Textbox(label="Answer"),
    title="Vision RAG: Multimodal Image Question Answering",
    description="Upload an image and ask any question about it. Powered by Qwen2.5-VL and Unsloth.",
    # Hide the flag button. `flagging_mode` is the Gradio 5 name for the
    # deprecated `allow_flagging` argument.
    flagging_mode="never",
)

# Launch the Gradio app.
# debug=True keeps this cell running so errors and logs stay visible; interrupt
# the cell to stop the server.
# share=True creates a temporary public URL: anyone who has the link can use
# the app (and this runtime's GPU) until the link expires or the server stops.
iface.launch(debug=True, share=True)



Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://7a91118d28f6bbe7ed.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://7a91118d28f6bbe7ed.gradio.live


