## Requirements and setup

* Google Colab with access to a GPU.
* Data source (Provided).

In [None]:
import os
from google.colab import userdata
# Read the GitHub access token stored under the Colab secret named 'github'.
github_token = userdata.get('github')
# Clone the project repository over HTTPS, authenticating with the token.
# NOTE(review): the token is interpolated into the URL, so git will typically
# record it in the clone's remote URL (.git/config) — do not share that folder.
!git clone https://{github_token}@github.com/starsofchance/Local_MultiMedia_RAG_For_Research.git

In [None]:
#cell 1.1
# Perform Google Colab installs (if running in Google Colab)
# The installs are skipped entirely when the COLAB_GPU environment variable is absent.
import os

if "COLAB_GPU" in os.environ:
    print("[INFO] Running in Google Colab, installing requirements.")
    !pip install -U torch # requires torch 2.1.1+ (for efficient sdpa implementation)
    !pip install PyMuPDF # for reading PDFs with Python
    !pip install tqdm # for progress bars
    !pip install sentence-transformers # for embedding models
    !pip install accelerate # for quantization model loading
    !pip install bitsandbytes # for quantizing models (less storage space)
    !pip install pillow # Added for image processing
    !pip install qwen-vl-utils # provides process_vision_info for the Qwen-VL chat format

In [None]:
#cell 1.2
# for faster attention mechanism = faster LLM inference
# Make sure GPU runtime is enabled (nvidia-smi prints the attached GPU, or fails if none)
!nvidia-smi

# Install torch matching Colab CUDA (usually CUDA 12.1)
# NOTE(review): this pins the cu121 wheel index — confirm it matches the CUDA version nvidia-smi reports.
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# Then install flash-attn pre-built for PyTorch+CUDA
# --no-build-isolation makes the build use the torch installed above.
!pip install flash-attn==2.8.3 --no-build-isolation


In [None]:
# Sanity check: confirm that flash-attn can be imported and report its version.
try:
    import flash_attn
except ImportError as import_error:
    print(f"❌ Error importing flash-attn: {import_error}")
else:
    # Only reached when the import succeeded.
    print("✅ flash-attn is installed and importable!")
    print(f"Version: {flash_attn.__version__}")

In [None]:
#cell 2
# Upload your own PDF files (only needed if you did not clone the repo).
import os
import shutil
from google.colab import files

# Create folder
folder_name = "PDF_Files"
os.makedirs(folder_name, exist_ok=True)

# Upload local files (opens the Colab file picker and blocks until the upload finishes)
print("Select your PDF files from your computer:")
uploaded = files.upload()

# Move files into PDF_Files/
# NOTE(review): the source path assumes the uploads landed in /content — confirm the working directory.
for filename in uploaded.keys():
    src = f"/content/{filename}"                       # Original auto-upload location
    dst = os.path.join(folder_name, filename)          # Desired final location
    shutil.move(src, dst)
    print(f"Moved: {dst}")

print("\nAll uploaded files are now ONLY inside the PDF_Files/ folder.")


In [None]:
# [REPLACEMENT] Cell 3: Robust Multimodal Extraction (Raster Images + Vector Charts)
# Run this to process ALL files in your PDF_Files folder correctly.

import fitz
import os
import numpy as np
from tqdm.auto import tqdm

# --- Helper Functions ---

def clean_block_text(text: str) -> str:
    """Collapse every run of whitespace in ``text`` into one space and trim both ends."""
    tokens = text.split()
    return " ".join(tokens)

def box_intersect_or_near(box1, box2, threshold=20):
    """
    Return True when two (x0, y0, x1, y1) boxes overlap or lie within
    ``threshold`` units of each other on both axes.
    """
    # Horizontal gap larger than the threshold: too far apart.
    if box1[0] > box2[2] + threshold or box2[0] > box1[2] + threshold:
        return False
    # Vertical gap larger than the threshold: too far apart.
    if box1[1] > box2[3] + threshold or box2[1] > box1[3] + threshold:
        return False
    return True

def merge_boxes(boxes):
    """
    Return the smallest ``fitz.Rect`` enclosing every box in ``boxes``.

    Each box is indexed as (x0, y0, x1, y1). Returns None for an empty
    (or otherwise falsy) input.
    """
    if not boxes:
        return None
    # Gather each coordinate column, then take the extreme of each one.
    lefts, tops, rights, bottoms = ([box[i] for box in boxes] for i in range(4))
    return fitz.Rect(min(lefts), min(tops), max(rights), max(bottoms))

def cluster_elements(rects, threshold=30):
    """
    Groups nearby vector elements (lines, shapes) into coherent figures.

    Args:
        rects: Sequence of boxes indexable as (x0, y0, x1, y1). The sequence
            is NOT modified; clustering works on a sorted copy.
        threshold: Maximum gap (same units as the boxes) at which two boxes
            are still treated as belonging to the same figure.

    Returns:
        A list with one bounding box per cluster. A cluster that absorbed
        other boxes is the merged box returned by ``merge_boxes``; a box that
        absorbed nothing is returned unchanged. Empty/None input yields [].

    Note:
        Each cluster is grown in a single pass over the remaining boxes, so a
        box skipped early in the pass is not re-tested after the cluster has
        grown; it seeds or joins a later cluster instead.
    """
    if not rects:
        return []

    # Sort a copy by top edge so the caller's list is left untouched.
    pending = sorted(rects, key=lambda r: r[1])

    clusters = []
    while pending:
        current = pending[0]
        leftover = []
        for candidate in pending[1:]:
            if box_intersect_or_near(current, candidate, threshold):
                # Grow the cluster; later candidates are tested against the grown box.
                current = merge_boxes([current, candidate])
            else:
                leftover.append(candidate)
        clusters.append(current)
        pending = leftover

    return clusters

def open_and_read_pdf(pdf_path: str, filename: str, image_output_dir: str) -> list[dict]:
    """
    Reads PDF, detects both Images (Pixels) and Drawings (Vector Charts),
    renders them, and extracts text.

    Args:
        pdf_path: Path of the PDF file to open.
        filename: Name stored in each page record and used (without its
            extension) as the prefix of the rendered image file names.
        image_output_dir: Directory where figure PNGs are written; created
            if it does not exist.

    Returns:
        One dict per page with the keys "filename", "page_number" (1-based),
        "text", "images" (list of dicts with "image_path", "bbox" and
        "potential_caption") and the page statistics "page_char_count",
        "page_word_count", "page_sentence_count_raw" and "page_token_count".
    """
    pages_and_texts = []

    # Ensure output directory exists before any image is saved
    os.makedirs(image_output_dir, exist_ok=True)
    file_stem = os.path.splitext(filename)[0]

    # The context manager closes the document even if a page raises.
    with fitz.open(pdf_path) as doc:
        for page_number, page in enumerate(doc):
            # 1. Text Extraction (index 4 of each block tuple is its text)
            page_text_content = []
            for block in page.get_text("blocks"):
                cleaned = clean_block_text(block[4])
                if cleaned:
                    page_text_content.append(cleaned)
            full_page_text = "\n".join(page_text_content)

            # 2. Visual Extraction (Images + Drawings)
            valid_rects = []

            # A. Raster Images (Photos, Icons); entries with xref 0 are skipped
            for img in page.get_image_info(xrefs=True):
                if img['xref'] != 0:
                    valid_rects.append(fitz.Rect(img['bbox']))

            # B. Vector Drawings (Charts, Graphs)
            for draw in page.get_drawings():
                r = draw["rect"]
                # Filter noise (dots, page borders)
                if r.width < 10 or r.height < 10:
                    continue
                if r.width > page.rect.width * 0.95:
                    continue
                valid_rects.append(r)

            # 3. Cluster and Render
            # Group scattered lines/images into single Diagram figures
            clustered_figures = cluster_elements(valid_rects, threshold=50)

            image_metadata = []
            for img_idx, fig_rect in enumerate(clustered_figures):
                # Ignore tiny clusters
                if fig_rect.width < 60 or fig_rect.height < 60:
                    continue

                image_name = f"{file_stem}_p{page_number+1}_fig{img_idx}.png"
                image_save_path = os.path.join(image_output_dir, image_name)

                # Best effort: a figure that cannot be rendered or saved is
                # reported and skipped instead of aborting the whole PDF.
                try:
                    # Render the clipped area at 3x scale to a PNG
                    pix = page.get_pixmap(clip=fig_rect, matrix=fitz.Matrix(3.0, 3.0))
                    pix.save(image_save_path)
                except Exception as e:
                    print(f"[WARN] Failed to save image {image_name}: {e}")
                    continue

                image_metadata.append({
                    "image_path": image_save_path,
                    "bbox": [fig_rect.x0, fig_rect.y0, fig_rect.x1, fig_rect.y1],
                    "potential_caption": None # Can implement caption search if needed
                })

            pages_and_texts.append({
                "filename": filename,
                "page_number": page_number + 1,
                "text": full_page_text,
                "images": image_metadata,

                # Stats for downstream compatibility
                "page_char_count": len(full_page_text),
                "page_word_count": len(full_page_text.split()),
                "page_sentence_count_raw": len(full_page_text.split(". ")),
                # Rough estimate: about 4 characters per token
                "page_token_count": len(full_page_text)/4
            })

    return pages_and_texts

# --- MAIN EXECUTION FOR ALL FILES ---
# Runs open_and_read_pdf on every PDF in pdf_folder and collects the per-page
# records in all_pages_and_texts, which the later cells read.
pdf_folder = "PDF_Files"
image_output_folder = "Extracted_Images"
all_pages_and_texts = []

# List all PDFs (case-insensitive match on the ".pdf" extension)
pdf_files = [f for f in os.listdir(pdf_folder) if f.lower().endswith(".pdf")]

print(f"[INFO] Processing {len(pdf_files)} PDF files (Vector + Raster extraction)...")

for filename in tqdm(pdf_files):
    pdf_path = os.path.join(pdf_folder, filename)
    file_data = open_and_read_pdf(pdf_path, filename, image_output_folder)
    # Flatten: one shared list of page records across all files
    all_pages_and_texts.extend(file_data)

print(f"\n[INFO] Processing complete.")
print(f"Total pages processed: {len(all_pages_and_texts)}")

In [None]:
# Archive the extracted figure images into a single zip file.
# NOTE(review): uses absolute /content paths — assumes the notebook ran from /content.
!zip -r /content/Extracted_Images.zip /content/Extracted_Images

In [None]:
import random
import pandas as pd
from IPython.display import display, Image as IPImage, Markdown

# 1. Global Statistics
total_pages = len(all_pages_and_texts)
total_images = sum(len(p['images']) for p in all_pages_and_texts)
unique_files = set(p['filename'] for p in all_pages_and_texts)
# Guard against an empty run so the average does not divide by zero
avg_tokens_per_page = (
    sum(p['page_token_count'] for p in all_pages_and_texts) / total_pages
    if total_pages else 0.0
)

print(f"--- Dataset Statistics ---")
print(f"Files Processed: {len(unique_files)}")
print(f"Total Pages:     {total_pages}")
print(f"Total Images:    {total_images}")
print(f"Avg Tokens/Page: {avg_tokens_per_page:.1f}")
print("-" * 30)

# 2. Find a "Rich" Sample (Page with Images)
# We specifically filter for a page that has images to verify our multimodal logic.
# Any page with at least one image qualifies (not only pages with exactly one),
# so the warning below is only printed when there really are no images.
pages_with_images = [p for p in all_pages_and_texts if p['images']]

if pages_with_images:
    sample = random.choice(pages_with_images)
    print(f"\n[INSPECTION] Metadata for: {sample['filename']} (Page {sample['page_number']})")

    # Show Text Snippet (First 500 chars)
    print(f"\n--- Text Snippet (First 500 chars) ---")
    print(sample['text'][:500] + "..." if len(sample['text']) > 500 else sample['text'])

    # Show Image Metadata & Render every image on the sampled page
    print(f"\n--- Image Metadata ({len(sample['images'])} found) ---")
    for img in sample['images']:
        print(f"Path: {img['image_path']}")
        print(f"Caption Candidate: {img['potential_caption']}")
        print(f"BBox: {img['bbox']}")

        # Display the actual image in Colab
        display(IPImage(filename=img['image_path'], width=300))
        print("-" * 20)
else:
    print("\n[WARN] No images found in any processed pages.")

# 3. Pandas Overview (Optional - good for spotting outliers)
df = pd.DataFrame(all_pages_and_texts)
print("\n--- DataFrame Summary (Top 5 rows) ---")
if df.empty:
    # An empty frame has no columns, so selecting them would raise KeyError
    print("[WARN] No pages were processed; nothing to summarize.")
else:
    display(df[["filename", "page_number", "page_token_count", "images"]].head())

In [None]:
# Cell 4: Smart Inspection - Focus on Graphs & Charts
import random
import pandas as pd
from IPython.display import display, Image as IPImage

# 1. Global Statistics
# Totals are computed over every page record in all_pages_and_texts.
total_pages = len(all_pages_and_texts)
total_images = sum(len(p['images']) for p in all_pages_and_texts)
unique_files = set(p['filename'] for p in all_pages_and_texts)

print(f"--- Dataset Statistics ---")
print(f"Files Processed: {len(unique_files)}")
print(f"Total Pages:     {total_pages}")
print(f"Total Images:    {total_images}")
print("-" * 30)

# 2. Define "Interesting" Images (Heuristic: bounding-box area > 15,000, in page-coordinate units)
# This is meant to filter out small author photos (~50x50) and keep charts (~300x200+)
def get_image_area(bbox):
    """Return the area (width * height) of an (x0, y0, x1, y1) bounding box."""
    x0, y0, x1, y1 = bbox[0], bbox[1], bbox[2], bbox[3]
    return (x1 - x0) * (y1 - y0)

# Keep only the pages holding at least one image whose bbox area exceeds 15,000
rich_pages = [
    p for p in all_pages_and_texts
    if any(get_image_area(img['bbox']) > 15000 for img in p['images'])
]

if not rich_pages:
    print("\n[WARN] No pages with large images found. Try lowering the area threshold.")
else:
    # Inspect one randomly chosen page from the filtered list
    sample = random.choice(rich_pages)
    print(f"\n[INSPECTION] Metadata for: {sample['filename']} (Page {sample['page_number']})")

    # Text preview: first 300 characters, always followed by an ellipsis
    print(f"\n--- Text Snippet ---")
    print(sample['text'][:300] + "..." )

    print(f"\n--- Visuals Found on Page ({len(sample['images'])}) ---")

    # Largest visuals first, so the main chart is shown before small logos
    sorted_images = sorted(sample['images'], key=lambda x: get_image_area(x['bbox']), reverse=True)

    for img in sorted_images:
        area = get_image_area(img['bbox'])

        # Small visuals (area of 5000 or less) are reported but not rendered
        if not area > 5000:
            print(f"[Skipping small image/icon, size {area:.0f}]")
            continue

        print(f"Path: {img['image_path']}")
        print(f"Size Score: {area:.0f} (Likely a Chart/Figure)")
        display(IPImage(filename=img['image_path'], width=400))
        print("-" * 30)

In [None]:
# [UPDATED] Cell 5: Create chunks and link images based on Page Association
import re
from tqdm.auto import tqdm

def split_text_into_chunks(
    text: str,
    chunk_size: int = 1000,
    overlap: int = 200,
    min_chunk_chars: int = 50,
) -> list[dict]:
    """
    Splits text into overlapping chunks using a simple sliding window.

    Each window spans up to ``chunk_size`` characters. When a window does not
    reach the end of the text, its end is pulled back to the last space found
    in its final 50 characters so words are not cut in half. The window then
    advances by ``chunk_size - overlap`` characters, but never past the end
    of the chunk just produced, so no text is skipped. Windowing stops as soon
    as a window reaches the end of the text, so no trailing chunk is emitted
    that merely repeats the tail of the previous one.

    Args:
        text: The text to split. Empty/None yields [].
        chunk_size: Maximum number of characters per window (must be > 0).
        overlap: Characters shared between consecutive windows
            (must satisfy 0 <= overlap < chunk_size).
        min_chunk_chars: Chunks whose stripped text is not longer than this
            are dropped.

    Returns:
        A list of dicts with "text" (stripped chunk), "start_char_idx" and
        "end_char_idx" (slice bounds into ``text``; always <= len(text)).

    Raises:
        ValueError: If chunk_size/overlap would not advance the window.
    """
    if not text:
        return []
    if chunk_size <= 0 or not 0 <= overlap < chunk_size:
        raise ValueError(
            f"Need chunk_size > 0 and 0 <= overlap < chunk_size "
            f"(got chunk_size={chunk_size}, overlap={overlap})."
        )

    boundary_window = 50  # how far back from the cut to look for a space
    step = chunk_size - overlap
    text_len = len(text)

    chunks = []
    start = 0
    while start < text_len:
        end = min(start + chunk_size, text_len)
        if end < text_len:
            # Search only inside the current window (and after its first
            # character) so the adjusted end always stays beyond `start`.
            window_start = max(start + 1, end - boundary_window)
            last_space = text.rfind(" ", window_start, end)
            if last_space != -1:
                end = last_space

        chunk_text = text[start:end].strip()
        if len(chunk_text) > min_chunk_chars:
            chunks.append({
                "text": chunk_text,
                "start_char_idx": start,
                "end_char_idx": end
            })

        if end >= text_len:
            # The window already covers the rest of the text.
            break
        # Never jump past the end of this chunk, or text would be skipped.
        start = min(start + step, end)

    return chunks

def create_multimodal_chunks(pages_data: list[dict]) -> list[dict]:
    """
    Build the chunk records for every page and link each chunk to the images
    found on the same page.

    Linking is page-based: every text chunk of a page receives the paths of
    ALL images extracted from that page, rather than trying to match charts
    to captions.

    Args:
        pages_data: Page records carrying "filename", "page_number", "text",
            "images" and the page-level statistics.

    Returns:
        A flat list of chunk dicts (id, source info, text, image paths,
        chunk statistics and the carried-over page statistics).
    """
    final_chunks = []

    print("[INFO] Starting chunking process (Page-Based Association)...")

    for page in tqdm(pages_data):
        for piece in split_text_into_chunks(page["text"]):
            body = piece["text"]

            # Fresh list per chunk, holding the path of every image on this page
            if page["images"]:
                linked_paths = [img["image_path"] for img in page["images"]]
            else:
                linked_paths = []

            record = {
                "id": f"{page['filename']}_p{page['page_number']}_{piece['start_char_idx']}",
                "filename": page["filename"],
                "page_number": page["page_number"],

                # Content
                "text": body,
                "images": linked_paths,

                # Statistics of this chunk
                "chunk_char_count": len(body),
                "chunk_token_count": len(body) / 4,

                # Statistics copied from the source page
                "orig_page_char_count": page["page_char_count"],
                "orig_page_word_count": page["page_word_count"],
                "orig_page_sentence_count": page["page_sentence_count_raw"],
                "orig_page_token_count": page["page_token_count"]
            }
            final_chunks.append(record)

    print(f"[INFO] Created {len(final_chunks)} chunks from {len(pages_data)} pages.")
    return final_chunks

# Execute Chunking
# `chunks` is the flat list of chunk records that the embedding cell reads.
chunks = create_multimodal_chunks(all_pages_and_texts)

# Verify the metadata link (count the chunks that carry at least one image path)
chunks_with_images = [c for c in chunks if len(c['images']) > 0]
print(f"\n[INSPECTION] Total Chunks with Images: {len(chunks_with_images)}")

if chunks_with_images:
    print(f"First Linked Chunk Image Count: {len(chunks_with_images[0]['images'])}")

In [None]:
# [UPDATED] Cell 6: DataFrame Statistics & Inspection
import pandas as pd
from IPython.display import display

# Convert our list of chunk dictionaries to a pandas DataFrame
df = pd.DataFrame(chunks)

# 1. Inspect the columns
# Note: 'embedding' column will NOT be here yet (we generate it in the next step)
print(f"Dataframe Shape: {df.shape}")
print(f"Columns: {df.columns.tolist()}")

# 2. Display the first 5 rows
display_cols = [
    "filename",
    "page_number",
    "text",
    "images",
    "chunk_token_count",
    "orig_page_token_count"
]

print("\n--- Knowledge Base Preview ---")
# Left-align and wrap the text column so long chunks stay readable
display(df[display_cols].head().style.set_properties(subset=['text'], **{'text-align': 'left', 'white-space': 'pre-wrap'}))

# 3. Statistical Summary (Tokens & Images)
print("\n--- Statistical Summary ---")

# Add a column counting the images linked to each chunk (kept on df for the analysis below)
df['image_count'] = df['images'].apply(len)

# We analyze:
# - chunk_token_count: Size of text inputs
# - image_count: Verification of image linking (its max should be > 0 if any image was linked)
stats_cols = ["chunk_token_count", "orig_page_token_count", "image_count"]
display(df[stats_cols].describe().round(2))

In [None]:
# Load the Qwen3 embedding model, embed every chunk, and save the result to disk
from sentence_transformers import SentenceTransformer
import torch
import pandas as pd

# Use the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# 1. Load Model
# Qwen3-Embedding-0.6B (figures below are the author's notes — TODO confirm against the model card):
# - Score: 64.7 (Excellent)
# - Size: ~1.1GB (Leaves room for LLM)
# - Context: 32k tokens
print(f"[INFO] Loading embedding model: Qwen/Qwen3-Embedding-0.6B on {device}...")

embedding_model = SentenceTransformer(
    "Qwen/Qwen3-Embedding-0.6B",
    trust_remote_code=True,
    device=device
)

# 2. Embed
# Note: For Qwen models, it is often best practice to add a task instruction for queries,
# but for *documents* (what we are doing now), we usually embed them raw.
print("[INFO] Generating embeddings...")
chunk_texts = [chunk["text"] for chunk in chunks]

# We increase batch_size slightly since this model is efficient
embeddings = embedding_model.encode(chunk_texts, batch_size=64, show_progress_bar=True)

# 3. Store (row i of `embeddings` belongs to chunks[i])
for i, chunk in enumerate(chunks):
    chunk["embedding"] = embeddings[i]

# 4. Save to Disk (pickle keeps the embedding arrays as arrays)
df = pd.DataFrame(chunks)
save_path = "multimodal_rag_embeddings_qwen.pkl"
df.to_pickle(save_path)

print(f"\n[INFO] Embedding complete.")
print(f"Total Vectors: {len(chunks)}")
# NOTE(review): chunks[0] raises IndexError when no chunks were produced.
print(f"Vector Dimension: {len(chunks[0]['embedding'])}") # Expected: 1024 or 1536 for Qwen (usually larger than mpnet)
print(f"Saved to: {save_path}")

In [None]:
import pandas as pd
import torch
import numpy as np

# Use the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

print("[INFO] Loading Knowledge Base from Pickle...")

# 1. Load the DataFrame (Pickle preserves data types, so no string parsing needed)
# NOTE: unpickling can execute arbitrary code — only load files produced by this notebook.
df = pd.read_pickle("multimodal_rag_embeddings_qwen.pkl")

# 2. Convert embeddings to a PyTorch Tensor
# Each row of the 'embedding' column holds one vector; stack them into a
# single 2-D array, convert it to a float32 tensor, and move it to `device`.
text_embeddings = torch.tensor(np.stack(df["embedding"].tolist()), dtype=torch.float32).to(device)

# 3. Convert metadata to a list of dictionaries (for easy access during retrieval)
# Record i of this list corresponds to row i of text_embeddings.
pages_and_chunks = df.to_dict(orient="records")

print(f"[INFO] Loading complete.")
print(f"Device: {device}")
print(f"Embeddings Shape: {text_embeddings.shape}")
# Shape should be [Num_Chunks, Model_Dim] (e.g., [1416, 1024] for Qwen)

In [None]:
from sentence_transformers import SentenceTransformer, util
import torch
import textwrap
from IPython.display import display, Image

# 1. Load the same model used for the database
# Queries must be embedded with the same model as the stored chunks so the
# query and chunk vectors are comparable.
print("[INFO] Loading Qwen embedding model for retrieval...")
embedding_model = SentenceTransformer(
    "Qwen/Qwen3-Embedding-0.6B",
    trust_remote_code=True,
    device=device
)

def retrieve_relevant_resources(query: str, embeddings: torch.Tensor, n_resources: int = 5):
    """
    Embeds the query and returns the top-k most similar chunks from the database.

    Args:
        query: The search text.
        embeddings: Tensor with one chunk embedding per row.
        n_resources: Maximum number of results. It is capped at the number of
            rows in ``embeddings``, so asking for more results than there are
            chunks returns all of them instead of raising.

    Returns:
        A ``(scores, indices)`` pair of 1-D tensors ordered from most to
        least similar; ``indices`` are row positions in ``embeddings``.
    """
    # Embed the query with the module-level embedding model
    # NOTE(review): the Qwen3 embedding models document a query prompt
    # (prompt_name="query"); it is not applied here — confirm whether it should be.
    query_embedding = embedding_model.encode(query, convert_to_tensor=True)

    # Keep both operands on the same device before multiplying
    query_embedding = query_embedding.to(embeddings.device)

    # Calculate Similarity (Dot Product) against every chunk vector at once
    dot_scores = util.dot_score(query_embedding, embeddings)[0]

    # torch.topk raises when k exceeds the number of available scores
    k = max(0, min(n_resources, dot_scores.shape[0]))
    scores, indices = torch.topk(dot_scores, k=k)

    return scores, indices

def print_multimodal_results(query, scores, indices):
    """
    Print each retrieved chunk (score, source, wrapped text) and display the
    images linked to it.

    ``indices`` are positions in the module-level ``pages_and_chunks`` list;
    ``scores`` holds the matching similarity score for each index.
    """
    divider = "-" * 50

    print(f"\nQuery: '{query}'\n")
    print(divider)

    for score, idx in zip(scores, indices):
        record = pages_and_chunks[idx.item()]

        print(f"Score: {score:.4f}")
        print(f"Source: {record['filename']} (Page {record['page_number']})")
        print("\nText:")
        wrapped = textwrap.fill(record["text"], width=80)
        print(wrapped)

        # Multimodal part: show every image attached to this chunk, if any
        image_paths = record["images"]
        if image_paths:
            print(f"\n[Visual Context Found] Displaying {len(image_paths)} image(s):")
            for path in image_paths:
                display(Image(filename=path, width=400))

        print(divider)

# --- Test the Search ---
# Try a query relevant to your PDFs (Security/LLMs)
query = "in the paper called:Malware Detection at the Edge with Lightweight LLMs: A Performance Evaluation, the TON-IoT dataset includes what?"

# Retrieve the default top 5 chunks, then print them with their linked images
scores, indices = retrieve_relevant_resources(query, text_embeddings)
print_multimodal_results(query, scores, indices)

In [None]:
# Second retrieval check: the query quotes the opening of a paragraph from the paper.
query = "complete this paragraph:TON-IoT. The TON-IoT dataset, released in 2021 by Moustafa et al. [6], stands as one of"

scores, indices = retrieve_relevant_resources(query, text_embeddings)
print_multimodal_results(query, scores, indices)


In [None]:
# Cell 12: Memory Cleanup & LLM Loading
import gc
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
from qwen_vl_utils import process_vision_info
# 1. FORCE CLEANUP
# Delete the embedding model and any large intermediate variables
# (the embedding model has to be loaded again before the next search)
if 'embedding_model' in globals():
    del embedding_model
if 'embeddings' in globals(): # The raw array, we keep text_embeddings tensor
    del embeddings

# Collect the now-unreferenced objects, then release cached GPU memory
gc.collect()
torch.cuda.empty_cache()

print(f"[INFO] Memory flushed.")
# mem_get_info()[0] is the free GPU memory in bytes; convert it to GB
print(f"GPU Memory Free: {torch.cuda.mem_get_info()[0] / 1024**3:.2f} GB")

# 2. Load the vision-language model (Qwen2-VL-7B-Instruct)
# 4-bit quantization is used to reduce the memory needed for the weights.


# Quantization Config: 4-bit NF4 weights, float16 compute, double quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

print("[INFO] Loading Qwen2-VL-7B-Instruct (Vision Model)...")

# Load The Vision Model (device_map="auto" lets the loader place the weights)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)

# 3. Load The Processor (Handles Images + Text)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)

print("[INFO] Vision Model Loaded Successfully!")

In [None]:
# 1. Reload the Embedding Model (Required for Search)
from sentence_transformers import SentenceTransformer

# Only load the model when no global named embedding_model exists
# (for example after it was deleted to free memory).
if 'embedding_model' not in globals():
    print("[INFO] Reloading Qwen Embedding Model for Search...")
    embedding_model = SentenceTransformer(
        "Qwen/Qwen3-Embedding-0.6B",
        trust_remote_code=True,
        device=device
    )
    print("[INFO] Embedding Model Loaded.")
else:
    print("[INFO] Embedding Model already active.")



In [None]:
################### Multimedia (multimodal) model

In [None]:
# Cell 20: Multimodal RAG (Vision-Language Model with Controls)
# [UPDATED] Cell 20: Multimodal RAG with Semantic Image Filtering
from IPython.display import display, Image as IPImage
from sentence_transformers import util

def talk_to_rag(query, target_file=None, top_k=25, top_visuals=2, score_threshold=0.35, debug=False):
    """
    Multimodal RAG with Smart Image Filtering.

    Builds a text context (either one whole document or the top retrieved
    chunks), selects up to ``top_visuals`` images from sufficiently relevant
    chunks, asks the vision-language model, and prints the answer, the
    sources (search mode) and the images that were analyzed.

    Args:
        query: The user's question.
        target_file: When given, the whole text of this file is used as
            context instead of searching the database.
        top_k: Number of chunks to retrieve in search mode.
        top_visuals: Maximum number of images passed to the model.
        score_threshold: If an image's surrounding text doesn't match the
            query above this score, the image is hidden (prevents irrelevant
            diagrams). The default is 0.35.
        debug: Print ranking and image keep/skip decisions.

    Returns:
        None. Results are printed/displayed. Returns early (after printing an
        error) when ``target_file`` matches no chunk.
    """

    context_text = ""
    images_to_process = []
    sources_used = []

    # --- MODE 1: TARGET SPECIFIC FILE (Now with Smart Image Selection) ---
    if target_file:
        print(f"[MODE] Reading full document: {target_file}")

        # 1. Get all chunks for this file
        file_indices = [i for i, c in enumerate(pages_and_chunks) if c['filename'] == target_file]

        if not file_indices:
            print(f"[ERROR] File '{target_file}' not found.")
            return

        # 2. Build Full Context (Text is still the WHOLE paper)
        # We sort by page number so the LLM reads the paper in order
        file_chunks = [pages_and_chunks[i] for i in file_indices]
        file_chunks.sort(key=lambda x: x['page_number'])
        full_text = "\n\n".join([c['text'] for c in file_chunks])
        context_text = f"DOCUMENT: {target_file}\nCONTENT:\n{full_text}"

        # 3. SMART IMAGE SELECTION
        # Instead of taking the first images blindly, we rank the file's chunks by relevance.
        if debug: print(f"[DEBUG] Ranking {len(file_indices)} chunks in file for visual relevance...")

        # Get embeddings just for this file
        target_embeddings = text_embeddings[file_indices]
        query_vec = embedding_model.encode(query, convert_to_tensor=True)

        # Calculate similarity within this file
        scores = util.dot_score(query_vec, target_embeddings)[0]

        # Pair up (Score, Chunk)
        scored_chunks = []
        for i, score in enumerate(scores):
            idx = file_indices[i] # Original global index
            scored_chunks.append((score.item(), pages_and_chunks[idx]))

        # Sort by relevance (Highest Score First)
        scored_chunks.sort(key=lambda x: x[0], reverse=True)

        # Pick images ONLY from the most relevant parts of the paper
        for score, chunk in scored_chunks:
            if chunk['images']:
                # CHECK THRESHOLD: Is this image actually related?
                if score < score_threshold:
                    if debug: print(f"[SKIP] Image on Page {chunk['page_number']} (Score {score:.4f} < {score_threshold})")
                    continue

                for img in chunk['images']:
                    if img not in images_to_process and len(images_to_process) < top_visuals:
                        if debug: print(f"[KEEP] Image on Page {chunk['page_number']} (Score {score:.4f})")
                        images_to_process.append(img)

    # --- MODE 2: SEARCH DATABASE (Standard RAG) ---
    else:
        print(f"[MODE] Searching database for: '{query}' (Deep Search: Top {top_k})")
        scores, indices = retrieve_relevant_resources(query, text_embeddings, n_resources=top_k)

        context_items = []
        if debug: print("\n[DEBUG] Rankings:")

        for rank, idx in enumerate(indices):
            chunk = pages_and_chunks[idx.item()]
            score = scores[rank].item()

            if debug: print(f"#{rank+1}: {chunk['filename']} (Score: {score:.4f})")

            context_items.append(chunk["text"])

            src = f"{chunk['filename']} (Page {chunk['page_number']})"
            if src not in sources_used: sources_used.append(src)

            # Smart Image Selection (Standard Mode)
            if chunk["images"]:
                # Check threshold here too (Consistency).
                # The chunk's text is already in the context; only its images are skipped.
                if score < score_threshold:
                    if debug: print(f"[SKIP] Image on Page {chunk['page_number']} (Low Score)")
                    continue

                for img in chunk["images"]:
                    if img not in images_to_process and len(images_to_process) < top_visuals:
                        images_to_process.append(img)

        context_text = "\n\n---\n\n".join(context_items)

    # --- GENERATION ---
    if not images_to_process:
        print(f"[INFO] No relevant images found (Threshold: {score_threshold}). Generating text-only response...")
        content_payload = [] # Text only
    else:
        print(f"[INFO] Generating answer with {len(images_to_process)} visual inputs...")
        content_payload = []
        for img_path in images_to_process:
            content_payload.append({"type": "image", "image": img_path})

    # Add Text Prompt (the context is capped at 30,000 characters)
    system_prompt = f"""You are an academic assistant.
    Answer the user's question using the provided CONTEXT.
    If images are provided, use them to support your answer.

    CONTEXT:
    {context_text[:30000]}
    """

    content_payload.append({"type": "text", "text": f"{system_prompt}\n\nUSER QUERY: {query}"})

    messages = [{"role": "user", "content": content_payload}]

    # Process
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, _ = process_vision_info(messages)

    # Place the inputs on the model's own device instead of assuming "cuda"
    inputs = processor(
        text=[text],
        images=image_inputs,
        padding=True,
        return_tensors="pt"
    ).to(model.device)

    # Generate
    generated_ids = model.generate(**inputs, max_new_tokens=1024)
    # Drop the prompt tokens so only the newly generated answer is decoded
    generated_ids_trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    answer = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True)[0]

    print("\n" + "="*50)
    print(f"ANSWER:\n{answer}")
    print("="*50)

    # Report the sources collected in search mode (previously gathered but never shown)
    if sources_used:
        print("\n[Sources]")
        for src in sources_used:
            print(f"- {src}")

    # Show user what images were analyzed
    if images_to_process:
        print("\n[Visuals Analyzed]")
        for img_path in images_to_process:
            print(f"File: {img_path}")
            display(IPImage(filename=img_path, width=400))

In [None]:
# Target the whole dataset: no target_file is passed, so talk_to_rag searches all chunks.
query = "What does Fig1 in paper called Malware Detection at the Edge with Lightweight LLMs_ A Performance Evaluation shows?."

talk_to_rag(query, top_k=10, debug=True, top_visuals=2)



In [None]:
# Target a specific paper: target_file must match the stored filename exactly.
query = "the TON-IoT dataset is created rom what? what number of smaples is in it and what is it used for??"
talk_to_rag(query, top_k=10, debug=False, target_file="Malware Detection at the Edge with Lightweight LLMs_ A Performance Evaluation.pdf")


In [None]:
# After restarting the environment:

In [None]:
import os
from google.colab import userdata
# Read the GitHub access token stored under the Colab secret named 'github'.
github_token = userdata.get('github')
# Clone the project repository over HTTPS, authenticating with the token.
# NOTE(review): the token is interpolated into the URL, so git will typically
# record it in the clone's remote URL (.git/config) — do not share that folder.
!git clone https://{github_token}@github.com/starsofchance/Local_MultiMedia_RAG_For_Research.git

In [None]:
#cell 1.1
# Perform Google Colab installs (if running in Google Colab)
# Every install below is skipped when the COLAB_GPU environment variable is absent.
import os

if "COLAB_GPU" in os.environ:
    print("[INFO] Running in Google Colab, installing requirements.")
    !pip install -U torch # requires torch 2.1.1+ (for efficient sdpa implementation)
    !pip install PyMuPDF # for reading PDFs with Python
    !pip install tqdm # for progress bars
    !pip install sentence-transformers # for embedding models
    !pip install accelerate # for quantization model loading
    !pip install bitsandbytes # for quantizing models (less storage space)
    !pip install pillow # Added for image processing
    !pip install qwen-vl-utils # provides process_vision_info for the Qwen-VL chat format
    #cell 1.2
    # for faster attention mechanism = faster LLM inference
    # Make sure GPU runtime is enabled
    # Install torch matching Colab CUDA (usually CUDA 12.1)
    # NOTE(review): this reinstalls torch right after the `-U torch` install above — confirm both are needed.
    !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
    # Then install flash-attn pre-built for PyTorch+CUDA
    !pip install flash-attn==2.8.3 --no-build-isolation

In [None]:
import os
import torch
import pandas as pd
import fitz # PyMuPDF
import numpy as np
import re
from tqdm.auto import tqdm
from PIL import Image
from sentence_transformers import SentenceTransformer, util
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
from qwen_vl_utils import process_vision_info
from IPython.display import display, Image as IPImage

# --- HELPER FUNCTIONS (Must be defined before the Class) ---
# (Paste the helper functions 'open_and_read_pdf', 'create_multimodal_chunks', etc. here if they are not already defined in this session.)
# For brevity, this assumes the helpers from the extraction cell ("Cell 3") and the chunking cell ("Cell 5") above have been run.

class MultimodalRAG:
    """Retrieval-augmented question answering over PDF text chunks and their images.

    The instance keeps a list of chunk records (``pages_and_chunks``) and a
    matching tensor of embeddings (``text_embeddings``), persisted together as
    a pickled DataFrame at ``db_path``. Retrieval uses a SentenceTransformer
    embedding model; answers are generated by a 4-bit quantized Qwen2-VL model
    that receives the retrieved text plus a limited number of images.

    Typical use: ``rag = MultimodalRAG(); rag.load_models();
    rag.ingest_new_files(); rag.chat("...")``.
    """

    def __init__(self, db_path="multimodal_rag_embeddings_qwen.pkl",
                 pdf_folder="PDF_Files",
                 image_folder="Extracted_Images"):
        """Store paths, pick the compute device and load the database if present.

        Args:
            db_path: Pickle file holding the chunk records and embeddings.
            pdf_folder: Folder scanned for PDFs by ``ingest_new_files``.
            image_folder: Folder handed to the PDF reader helper for images.
        """
        self.db_path = db_path
        self.pdf_folder = pdf_folder
        self.image_folder = image_folder
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # State: one record per chunk, and one embedding row per record.
        self.pages_and_chunks = []
        self.text_embeddings = None

        # Models are loaded lazily by load_models().
        self.embedding_model = None
        self.vision_model = None
        self.processor = None

        # Auto-load the database so a restarted session can query immediately.
        self.load_database()

    def load_database(self):
        """Load the pickled chunk database at ``self.db_path`` into memory.

        Leaves the instance in its empty state when the file does not exist
        or contains no rows.
        """
        if not os.path.exists(self.db_path):
            print(f"[INIT] Database not found. Starting fresh.")
            return

        print(f"[INIT] Loading database from {self.db_path}...")
        # NOTE(review): unpickling can execute arbitrary code; only load
        # database files that this notebook (or a trusted source) produced.
        df = pd.read_pickle(self.db_path)

        # np.stack raises on an empty sequence, so treat a row-less file as "no data".
        if df.empty:
            print("[INIT] Database file is empty. Starting fresh.")
            return

        self.pages_and_chunks = df.to_dict(orient="records")

        # Stack the per-row embedding arrays into one (rows, dim) float32 tensor.
        self.text_embeddings = torch.tensor(
            np.stack(df["embedding"].tolist()), dtype=torch.float32
        ).to(self.device)
        print(f"[SUCCESS] Loaded {len(self.pages_and_chunks)} chunks.")

    def load_models(self):
        """Load the embedding model (search) and the Qwen2-VL model (generation)."""
        print("[LOAD] Loading Embedding Model...")
        self.embedding_model = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B", trust_remote_code=True, device=self.device)

        print("[LOAD] Loading Vision Model...")
        # 4-bit NF4 quantization with double quantization; compute runs in float16.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True
        )
        self.vision_model = Qwen2VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2-VL-7B-Instruct", quantization_config=bnb_config,
            device_map="auto", trust_remote_code=True
        )
        self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)
        print("[SUCCESS] Models Ready.")

    def ingest_new_files(self):
        """Embed PDFs in ``self.pdf_folder`` that are not yet in the database.

        A file counts as new when its name does not appear as the ``filename``
        of any stored chunk. New chunks are appended to the in-memory state and
        the whole database is rewritten to ``self.db_path``.

        Relies on the module-level helpers ``open_and_read_pdf`` and
        ``create_multimodal_chunks`` being defined in the notebook.
        """
        if self.embedding_model is None:
            print("[ERROR] Models not loaded. Run .load_models()")
            return

        os.makedirs(self.pdf_folder, exist_ok=True)
        existing_files = {c['filename'] for c in self.pages_and_chunks}
        disk_files = {f for f in os.listdir(self.pdf_folder) if f.lower().endswith(".pdf")}
        # Sorted so the processing order (and chunk order in the DB) is reproducible.
        new_files = sorted(disk_files - existing_files)

        if not new_files:
            print("[INFO] No new PDF files found.")
            return

        print(f"[UPDATE] Processing {len(new_files)} new files...")
        new_chunks_buffer = []

        for filename in tqdm(new_files):
            full_path = os.path.join(self.pdf_folder, filename)
            # Calls the global helper functions
            pages = open_and_read_pdf(full_path, filename, self.image_folder)
            chunks = create_multimodal_chunks(pages)

            # A PDF that yields no chunks has nothing to embed; skipping it
            # avoids stacking an empty embedding list further down.
            if not chunks:
                print(f"[WARN] No chunks extracted from {filename}; skipping.")
                continue

            texts = [c["text"] for c in chunks]
            vectors = self.embedding_model.encode(texts, batch_size=32, show_progress_bar=False)

            for chunk, vector in zip(chunks, vectors):
                chunk["embedding"] = vector
            new_chunks_buffer.extend(chunks)

        if not new_chunks_buffer:
            return

        self.pages_and_chunks.extend(new_chunks_buffer)
        new_tensor = torch.tensor(
            np.stack([c['embedding'] for c in new_chunks_buffer]), dtype=torch.float32
        ).to(self.device)

        if self.text_embeddings is None:
            self.text_embeddings = new_tensor
        else:
            self.text_embeddings = torch.cat((self.text_embeddings, new_tensor), dim=0)

        pd.DataFrame(self.pages_and_chunks).to_pickle(self.db_path)
        print(f"[SUCCESS] Database updated. Total Chunks: {len(self.pages_and_chunks)}")

    @staticmethod
    def _add_chunk_images(selected, chunk, score, top_visuals, score_threshold, debug=False):
        """Append the chunk's images to ``selected`` if the chunk scores high enough.

        Images already selected are not added twice, and ``selected`` never
        grows beyond ``top_visuals`` entries.
        """
        if not chunk['images']:
            return
        if score < score_threshold:
            if debug: print(f"[SKIP] Img on Page {chunk['page_number']} (Score {score:.2f} < {score_threshold})")
            return
        for img in chunk['images']:
            if img not in selected and len(selected) < top_visuals:
                selected.append(img)

    def chat(self, query, target_file=None, top_k=25, top_visuals=2, score_threshold=0.35, debug=False,
             max_context_chars=32000):
        """Answer ``query`` from the database and print the answer.

        Args:
            query: The user's question.
            target_file: If set, the context is every chunk of that one file
                (in page order) and ``top_k`` is not used. The name must equal
                the stored ``filename`` exactly.
            top_k: Number of best-scoring chunks used when searching the whole
                database; clamped to the number of stored chunks.
            top_visuals: Max images to show the model.
            score_threshold: Minimum dot-product score a chunk needs for its
                images to be considered. Treating this as a 0.0-1.0 similarity
                assumes the embedding model emits normalized vectors — confirm.
            debug: Print per-chunk ranking / skipped-image details.
            max_context_chars: Character cap applied to the joined context text.

        Returns:
            None. The answer and any analyzed images are printed/displayed.
        """
        if self.vision_model is None or self.embedding_model is None or self.processor is None:
            print("[ERROR] Models not loaded.")
            return
        if self.text_embeddings is None or not self.pages_and_chunks:
            print("[ERROR] Database is empty. Run .ingest_new_files() first.")
            return

        context_items = []
        images_to_process = []
        sources = []

        # Embed Query
        query_vec = self.embedding_model.encode(query, convert_to_tensor=True)

        if target_file:
            # --- Single-document mode: read one specific paper ---
            print(f"[MODE] Reading document: {target_file}")
            file_indices = [i for i, c in enumerate(self.pages_and_chunks) if c['filename'] == target_file]

            if not file_indices:
                print(f"[ERROR] File not found.")
                return

            # Score only this file's chunks; scores are used for image selection.
            scores = util.dot_score(query_vec, self.text_embeddings[file_indices])[0]
            scored_chunks = [
                (score, self.pages_and_chunks[idx])
                for idx, score in zip(file_indices, scores.tolist())
            ]

            # Text context is the whole paper in page order for reading flow.
            in_page_order = sorted((c for _, c in scored_chunks), key=lambda c: c['page_number'])
            context_items = [c['text'] for c in in_page_order]
            sources.append(f"Full Document: {target_file}")

            # Images are taken from the most relevant chunks first.
            scored_chunks.sort(key=lambda pair: pair[0], reverse=True)
            for score, chunk in scored_chunks:
                self._add_chunk_images(images_to_process, chunk, score, top_visuals, score_threshold, debug)

        else:
            # --- Database search mode: top-k chunks across all files ---
            scores = util.dot_score(query_vec, self.text_embeddings)[0]
            # torch.topk raises if k exceeds the number of scores, so clamp it.
            k = max(0, min(top_k, scores.shape[0]))
            print(f"[MODE] Searching DB (Top {k})...")
            top_scores, top_indices = torch.topk(scores, k=k)

            for rank, (score, idx) in enumerate(zip(top_scores.tolist(), top_indices.tolist())):
                chunk = self.pages_and_chunks[idx]

                if debug: print(f"#{rank+1}: {chunk['filename']} ({score:.4f})")

                context_items.append(chunk["text"])
                src = f"{chunk['filename']} (Page {chunk['page_number']})"
                if src not in sources: sources.append(src)

                self._add_chunk_images(images_to_process, chunk, score, top_visuals, score_threshold, debug)

        # --- GENERATION ---
        msg = f"[AI] Thinking with {len(images_to_process)} images..."
        if not images_to_process and score_threshold > 0: msg += " (others filtered by threshold)"
        print(msg)

        # Images come first in the user message, followed by the prompt text.
        content_payload = [{"type": "image", "image": img} for img in images_to_process]

        context_text = ' '.join(context_items)[:max_context_chars]
        sys_prompt = f"""You are an academic assistant.
        Answer the user's question using the provided CONTEXT.
        If images are provided, use them to support your answer.
        CONTEXT: {context_text}"""

        content_payload.append({"type": "text", "text": f"{sys_prompt}\n\nQUERY: {query}"})

        messages = [{"role": "user", "content": content_payload}]
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = process_vision_info(messages)
        inputs = self.processor(text=[text], images=image_inputs, padding=True, return_tensors="pt").to(self.device)

        gen_ids = self.vision_model.generate(**inputs, max_new_tokens=1024)
        # generate() returns prompt + completion. Decoding with
        # skip_special_tokens=True removes the chat markers, so the answer cannot
        # be recovered by splitting on them; slice off the prompt tokens instead.
        prompt_len = inputs.input_ids.shape[1]
        answer = self.processor.batch_decode(gen_ids[:, prompt_len:], skip_special_tokens=True)[0].strip()

        print("\n" + "="*40 + f"\nANSWER:\n{answer}\n" + "="*40)

        if images_to_process:
            print("\n[Visuals Analyzed]")
            for img in images_to_process: display(IPImage(filename=img, width=300))

In [None]:
# Cell 3: Initialize (The Start Button)

# Run this to wake everything up.
# MultimodalRAG() with default arguments loads the pickle database if the file exists;
# load_models() then loads the embedding model and the 4-bit quantized vision model.
rag = MultimodalRAG()
rag.load_models()

Mode 1: Search the Whole Database (Default)

Use this when you want to compare papers or find answers from "somewhere" in your data.
Python

# Query only
rag.chat("What is the difference between data poisoning and prompt injection?")

Mode 2: Talk to One Specific Paper

Use this when you want to summarize a file or ask about a specific figure in a specific file. Note: the filename is matched by exact string comparison against the stored name, so copy it exactly (including .pdf).
Python

# Query + target_file
rag.chat("Summarize the methodology of this paper", target_file="My_Paper.pdf")

Bonus: Search + Visual Control

You can combine search with the visual limit by passing `top_visuals` (e.g., `top_visuals=4` lets the model look at up to 4 images).
Python

rag.chat("Compare the accuracy charts", top_visuals=4)

In [None]:

# Talk
# No target_file is given, so chat() searches across all chunks in the database.
rag.chat("what is the goal in Prompt injection attacks on vision language models in oncology? what does the attacers trying to achive?")

In [None]:
# Database-wide comparison question (default top_k, top_visuals and score_threshold).
rag.chat("What is the difference between data poisoning and prompt injection?")

In [None]:
# Single-paper query: the context is every chunk of the named file.
# NOTE: chat() only uses top_k when target_file is not set, so top_k=10 has no effect here.
rag.chat("explain the METHODOLOGY in the paper",
         target_file="BadPre Task-agnostic Backdoor Attacks to Pre-trained NLP Foundation Models.pdf",
         top_k=10)