In [None]:
# %pip install langchain-text-splitters langchain-community sentence-transformers pypdf langchain-qdrant

In [None]:
# %pip install pdf2image qdrant-client pillow

In [None]:
# %pip install pymupdf pillow

In [None]:
!nvidia-smi

In [None]:
# 0. Imports
import os
from typing import List, Tuple, Any, Dict

from pdf2image import convert_from_path
from PIL import Image

import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor, pipeline

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_core.documents import Document

from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams

import fitz

In [None]:
# Load & Split PDF + Store with page metadata
PDF_PATH = "Dataset/QwenVL2.5.pdf"

# PyPDFLoader yields one Document per PDF page.
loader = PyPDFLoader(PDF_PATH)
documents: List[Document] = loader.load()

# Add page number to metadata if not present
for i, doc in enumerate(documents):
    doc.metadata.setdefault("page", i)  # 0-based page index

# Split pages into overlapping chunks; each chunk inherits its page's metadata,
# which is what later lets an answer be tied back to a page image.
splitter = RecursiveCharacterTextSplitter(chunk_size=1600, chunk_overlap=50)
docs = splitter.split_documents(documents)

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# In-memory Qdrant: the index lives only as long as this kernel.
# Pass a path or a remote URL to QdrantClient instead to persist it.
client = QdrantClient(":memory:")

COLLECTION_NAME = "pdf_rag"
if not client.collection_exists(COLLECTION_NAME):
    # Ask the embedding model for its output dimension instead of hard-coding it,
    # so swapping the model above cannot silently mismatch the collection schema.
    vector_size = len(embeddings.embed_query("dimension probe"))
    client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
    )

vector_store = QdrantVectorStore(
    client=client,
    collection_name=COLLECTION_NAME,
    embedding=embeddings,
)

# ADD DOCUMENTS WITH PAGE METADATA
# LangChain-Qdrant automatically stores metadata
if docs:
    vector_store.add_documents(docs)
else:
    # A PDF with no extractable text (e.g. scanned pages) produces no chunks.
    print(f"[WARNING] No text chunks extracted from {PDF_PATH}; the index is empty.")

# Create retriever (MMR for diversity)
retriever = vector_store.as_retriever(search_type="mmr", search_kwargs={"k": 5})

In [None]:
# Hugging Face Hub id of the vision-language checkpoint.
# load_qwen_vl() passes this to AutoProcessor.from_pretrained.
model_name = "Qwen/Qwen3-VL-4B-Instruct"

In [None]:
# Load Qwen3-VL-4B-Instruct (GPU if available)
def load_qwen_vl() -> Any:
    """Load the vision-language model and return a ``messages -> str`` callable.

    The returned callable accepts chat messages of the form
    ``{"role": ..., "content": [{"type": "image", "image": PIL.Image},
    {"type": "text", "text": "..."}]}`` (``content`` may also be a plain
    string) and returns the decoded, stripped model reply.

    If loading the model or processor fails for any reason, the error is
    printed and a small text-only ``flan-t5-small`` pipeline is returned
    instead; that fallback ignores images entirely.

    Returns:
        A function taking ``List[Dict]`` chat messages and returning ``str``.
    """
    try:
        # Use the shared `model_name` for both model and processor so the two
        # can never point at different checkpoints.
        model = Qwen3VLForConditionalGeneration.from_pretrained(
            model_name,
            dtype=torch.bfloat16,
            device_map="auto",
        )
        processor = AutoProcessor.from_pretrained(model_name)

        def _prepare(messages: List[Dict], keep_images: bool = True) -> Tuple[List[Dict], List[Any]]:
            """Split messages into template-ready chat turns and the list of images."""
            images: List[Any] = []
            chat: List[Dict] = []
            for msg in messages:
                raw = msg.get("content")
                if isinstance(raw, str):
                    # Plain-string content is shorthand for a single text part.
                    raw = [{"type": "text", "text": raw}]
                parts: List[Dict] = []
                if isinstance(raw, list):
                    for item in raw:
                        kind = item.get("type")
                        if kind == "image":
                            if keep_images:
                                images.append(item["image"])
                                # The template only needs a placeholder; pixels go to the processor.
                                parts.append({"type": "image"})
                        elif kind == "text":
                            parts.append({"type": "text", "text": item["text"]})
                # Skip turns with nothing in them: an empty trailing assistant turn
                # would otherwise be rendered *and* followed by the generation prompt.
                if parts:
                    chat.append({"role": msg["role"], "content": parts})
            return chat, images

        def qwen_vl(messages: List[Dict]) -> str:
            """Run one generation for the given chat messages and return the reply text."""
            chat, images = _prepare(messages)

            # Use processor's apply_chat_template
            text = processor.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

            # Process with images and text
            try:
                inputs = processor(text=text, images=images if images else None, return_tensors="pt")
            except Exception as e:
                print(f"[Processor error] {e}. Trying text-only mode...")
                # Re-render the prompt without image parts rather than string-stripping
                # tokens out of the rendered template, which would also damage the
                # chat-format special tokens.
                text_chat, _ = _prepare(messages, keep_images=False)
                text_only = processor.apply_chat_template(
                    text_chat, tokenize=False, add_generation_prompt=True
                )
                inputs = processor(text=text_only, return_tensors="pt")

            # device_map="auto" decides placement; follow the model instead of assuming "cuda".
            inputs = inputs.to(model.device)

            generated_ids = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=False,
            )
            # generate() returns prompt + completion; keep only the newly generated tokens.
            generated_ids = generated_ids[:, inputs["input_ids"].shape[1]:]

            response = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            return response.strip()

        return qwen_vl

    except Exception as e:
        print(f"[Qwen-VL] Load error: {e}")
        # Fallback tiny text model
        fallback = pipeline(
            "text2text-generation",
            model="google/flan-t5-small",
            max_length=256,
            device=0 if torch.cuda.is_available() else -1,
        )

        def fallback_vlm(messages: List[Dict]) -> str:
            """Text-only stand-in: concatenates every text part and ignores images."""
            pieces: List[str] = []
            for m in messages:
                content = m.get("content")
                if isinstance(content, str):
                    pieces.append(content)
                    continue
                for c in content or []:
                    if c.get("type") == "text":
                        pieces.append(c["text"])
            txt = " ".join(pieces)
            out = fallback(txt, do_sample=False)
            return out[0]["generated_text"] if isinstance(out, list) else str(out)

        return fallback_vlm


# Build the chat callable once at module level; rag_answer() invokes it as vlm_pipeline(messages).
vlm_pipeline = load_qwen_vl()

In [None]:
# Helpers
def extract_text_and_page(item) -> Tuple[str, int]:
    """Return ``(text, page)`` for a retrieved item.

    Accepts a dict with ``page_content`` / ``metadata`` keys, any object
    exposing a ``page_content`` attribute (such as a LangChain Document), or
    anything else, which is stringified.

    Args:
        item: The retrieved object.

    Returns:
        The stripped text and the page index as an ``int``. The page is ``-1``
        whenever it is missing or cannot be converted to an integer, so callers
        can always compare it numerically.
    """
    if isinstance(item, dict):
        text = item.get("page_content", str(item))
        meta = item.get("metadata")
    elif hasattr(item, "page_content"):
        # Duck-typed so any Document-like object works, not just one concrete class.
        text = item.page_content
        meta = getattr(item, "metadata", None)
    else:
        text = str(item)
        meta = None

    # A key can be present but explicitly None; treat that as missing.
    if not isinstance(meta, dict):
        meta = {}
    text = "" if text is None else str(text)

    raw_page = meta.get("page", -1)
    if isinstance(raw_page, bool):
        # bool is an int subclass, but True/False is never a real page index.
        page = -1
    else:
        try:
            page = int(raw_page)
        except (TypeError, ValueError, OverflowError):
            page = -1
    return text.strip(), page


def get_page_images(pdf_path: str, page_numbers: List[int], max_pages: int = 5) -> List[Image.Image]:
    """Render selected PDF pages to PIL images.

    Args:
        pdf_path: Path to the PDF file.
        page_numbers: 0-based page indices; duplicates, negatives and indices
            past the end of the document are ignored.
        max_pages: Upper bound on the number of pages rendered.

    Returns:
        One RGB image per selected page, in ascending page order. An empty list
        is returned when the file is missing, no valid page is requested, or
        rendering fails (the error is printed rather than raised).
    """
    if not os.path.exists(pdf_path):
        print(f"[ERROR] PDF not found: {pdf_path}")
        return []

    requested = {p for p in page_numbers if isinstance(p, int) and p >= 0}
    if not requested or max_pages <= 0:
        return []

    try:
        # The context manager closes the document even if rendering raises.
        with fitz.open(pdf_path) as doc:
            page_count = len(doc)
            # Drop out-of-range pages *before* truncating so they don't use up the budget.
            selected = sorted(p for p in requested if p < page_count)[:max_pages]

            images = []
            for page_num in selected:
                pix = doc[page_num].get_pixmap(dpi=100)  # REDUCED DPI from 200 to save memory
                images.append(Image.frombytes("RGB", (pix.width, pix.height), pix.samples))

        return images

    except Exception as e:
        print(f"[PDF→Image] Error with PyMuPDF: {e}")
        return []

In [None]:
# RAG + Multimodal Answer Function
def _detect_task(query: str) -> str:
    """Guess the task from keywords in the query; the first matching group wins."""
    q = query.lower()
    keyword_groups = (
        ("summarize", ("summarize", "summarise", "summary", "tl;dr")),
        ("ocr", ("ocr", "extract text", "read text")),
        ("vqa", ("what is", "describe", "what do you see", "vqa")),
        ("ground", ("where is", "locate", "ground", "bbox")),
    )
    for task_name, keywords in keyword_groups:
        if any(word in q for word in keywords):
            return task_name
    return "answer"


def _build_instruction(task: str, query: str) -> str:
    """Return the task-specific instruction text appended to the user turn."""
    if task == "summarize":
        return "Summarize the document within 200 words"
    if task == "ocr":
        return "Extract **all readable text** from the provided page images. Return only plain text."
    if task == "vqa":
        return query
    if task == "ground":
        return (f"Answer the question and, if you refer to objects in the images, "
            f"return bounding boxes in format [x1,y1,x2,y2] (normalized 0–1).\nQuestion: {query}")
    return (f"Use the provided text context and images to answer concisely.\n"
        f"If the answer cannot be found, say *I don't know*.\nQuestion: {query}")


def rag_answer(query: str, task: str = None, k: int = 5, pdf_path: str = PDF_PATH, max_page_images: int = 2) -> str:
    """Answer a query from retrieved PDF text plus rendered page images.

    Args:
        query: The user question or instruction.
        task: One of "summarize", "ocr", "vqa", "ground", "answer". When
            ``None`` it is inferred from keywords in ``query``; any other value
            is treated like "answer".
        k: Maximum number of retrieved chunks to use. It caps the retriever's
            results and is passed to the similarity-search fallback; it cannot
            raise the retriever's own configured limit.
        pdf_path: PDF to render page images from.
        max_page_images: Maximum number of page images attached to the prompt.

    Returns:
        The model's reply, or a short message string when the PDF is missing,
        nothing is retrieved, no text is usable, or the model call raises.
    """
    if not os.path.exists(pdf_path):
        return "Error: PDF file not found."

    # Retrieve
    docs = None
    try:
        if hasattr(retriever, "invoke"):
            docs = retriever.invoke(query)
        elif hasattr(retriever, "get_relevant_documents"):
            docs = retriever.get_relevant_documents(query)
    except Exception as e:
        # Report the failure instead of hiding it, then fall through to the fallback.
        print(f"[Retriever] Error: {e}. Falling back to similarity search...")
        docs = None

    if docs is None:
        try:
            docs = vector_store.similarity_search(query, k=k)
        except Exception as e:
            print(f"[Vector store] Error: {e}")
            docs = None

    if not docs:
        return "No documents retrieved."

    # Honor `k` as an upper bound even when the retriever has its own limit.
    if k and k > 0:
        docs = docs[:k]

    # Extract text + pages
    text_page_pairs = [extract_text_and_page(d) for d in docs]
    context_parts = [text for text, _ in text_page_pairs if text]
    page_numbers = [page for _, page in text_page_pairs if isinstance(page, int) and page >= 0]

    if not context_parts:
        return "No usable text."

    context = "\n---\n".join(context_parts)

    # Pick the pages to render here so the image budget is enforced and the page
    # list quoted in the prompt matches the images that are actually attached.
    image_pages = sorted(set(page_numbers))[:max(max_page_images, 0)]
    page_images = get_page_images(pdf_path, image_pages, max_pages=max_page_images)

    # Task detection
    if task is None:
        task = _detect_task(query)

    # Build Qwen-VL messages
    if page_images:
        content: List[Dict] = [{"type": "image", "image": img} for img in page_images]
        content.append({
            "type": "text",
            "text": f"Context (PDF pages {[p + 1 for p in image_pages]}):\n{context}\n\n",
        })
    else:
        content = [{"type": "text", "text": f"Context (text only):\n{context}\n\n"}]

    # Task-specific instruction
    content.append({"type": "text", "text": _build_instruction(task, query)})

    # No empty assistant turn is appended: the model wrapper adds the generation
    # prompt itself, and an extra blank turn would precede it in the rendered chat.
    messages: List[Dict] = [{"role": "user", "content": content}]

    # Call VL model
    try:
        return vlm_pipeline(messages)
    except Exception as e:
        return f"Model error: {str(e)}"

In [None]:
# TEST IT!
if __name__ == "__main__":
    # Step 1: check that the retriever returns chunks on its own, before the
    # model is involved.
    print("Testing retrieval...")
    print(retriever.invoke("Vision Encoder"))

    # Step 2: run the full pipeline (retrieval, page rendering, model call)
    # with a single page image attached.
    print(f"\n{'=' * 60}")
    print("RAG + Vision Answer:")
    answer = rag_answer(
        "Summarise Vision Encoder",
        task="summarize",
        k=5,
        pdf_path=PDF_PATH,
        max_page_images=1,
    )
    print(answer)