In [1]:
"""
Complete Local RAG Chatbot with Image Understanding
===================================================

✅ No Cloud Dependencies (runs 100% locally)
✅ No RAGatouille (direct Jina ColBERT v2 implementation)
✅ PyMuPDF4LLM for PDF conversion
✅ Image extraction and analysis with a local vision model (Gemma 3 via Ollama)
✅ Hybrid retrieval (BM25s + Jina ColBERT v2 + RRF + Reranking)
✅ Markdown-aware semantic chunking
✅ SQLite database for storage

Requirements:
- Ollama (for LLMs: llama3.2:3b for chat, gemma3:4b for image analysis)
- Mac Mini M4 or similar (16GB RAM recommended)
"""



In [2]:
import os
# Suppress tokenizers parallelism warning when forking
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import json
import re
import io
import time
import warnings
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path

# Suppress deprecation warnings from transformers/sentence-transformers
warnings.filterwarnings('ignore', message='.*torch_dtype.*deprecated.*')

# Core libraries
import numpy as np
import torch
from PIL import Image as PILImage  # Renamed to avoid conflict with database model

# PDF and text processing
import pymupdf4llm
import fitz  # PyMuPDF for image extraction
from transformers import AutoTokenizer

# Retrieval
import bm25s
from bm25s.hf import BM25HF
import Stemmer  # PyStemmer for stemming
from sentence_transformers import SentenceTransformer

# Database
import sqlalchemy
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, Boolean
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import DeclarativeBase

# LLM
import requests  # For Ollama API

In [3]:
# ============================================================================
# CONFIGURATION
# ============================================================================

@dataclass
class RAGConfig:
    """Configuration for the local RAG system.

    Every field has a default. The four path fields that default to ``None``
    (``db_path``, ``bm25_index_path``, ``colbert_index_path``, ``images_dir``)
    are filled in by ``__post_init__`` with absolute paths below ``base_dir``;
    values passed explicitly are kept exactly as given.
    """
    # Base directory (set to project root - parent of notebooks folder).
    # The default is computed once, when this class is defined, from the
    # working directory at that moment.
    base_dir: str = os.path.abspath(os.path.join(os.getcwd(), '..'))

    # Database (None -> <base_dir>/rag_local.db)
    db_path: Optional[str] = None

    # Chunking (sizes are token counts as measured by the chunker's tokenizer)
    min_chunk_size: int = 256
    max_chunk_size: int = 512  # Reduced to match model's max_seq_length
    chunk_overlap: int = 128

    # Retrieval
    bm25_top_k: int = 100
    colbert_top_k: int = 100
    final_top_k: int = 15  # Increased from 10 to 15 for better coverage and reduced hallucination

    # Models
    chat_model: str = "llama3.2:3b"
    vision_model: str = "gemma3:4b"
    embedding_model: str = "jinaai/jina-colbert-v2"

    # Ollama
    ollama_url: str = "http://localhost:11434"
    ollama_timeout: int = 300  # Increased timeout for slower models

    # Paths (None -> absolute path below base_dir, see __post_init__)
    bm25_index_path: Optional[str] = None
    colbert_index_path: Optional[str] = None
    images_dir: Optional[str] = None

    # Device: Apple-silicon GPU ("mps") when torch reports it available, else CPU.
    device: str = "mps" if torch.backends.mps.is_available() else "cpu"

    def __post_init__(self):
        """Fill in every path that was left as ``None`` with an absolute path.

        ``base_dir`` itself is not modified, but it is resolved with
        ``os.path.abspath`` before joining so the derived paths are absolute
        even when a relative ``base_dir`` was supplied.
        """
        root = os.path.abspath(self.base_dir)
        derived_paths = {
            "db_path": ("rag_local.db",),
            "bm25_index_path": ("indexes", "bm25s"),
            "colbert_index_path": ("indexes", "colbert"),
            "images_dir": ("extracted_images",),
        }
        for field_name, parts in derived_paths.items():
            if getattr(self, field_name) is None:
                setattr(self, field_name, os.path.join(root, *parts))

In [4]:
# ============================================================================
# DATABASE MODELS
# ============================================================================

class Base(DeclarativeBase):
    """Declarative base that the ORM model classes in this file inherit from."""

class Document(Base):
    """ORM model for the ``documents`` table.

    NOTE(review): presumably one row per ingested PDF - confirm against the
    ingestion code, which is not part of this block.
    """
    __tablename__ = 'documents'
    
    id = Column(Integer, primary_key=True)
    filename = Column(String(255), nullable=False)
    # Default is a callable, so the (naive, UTC) timestamp is taken per insert.
    # NOTE(review): datetime.utcnow is deprecated as of Python 3.12 - consider
    # a timezone-aware replacement when the imports are next touched.
    upload_date = Column(DateTime, default=datetime.utcnow)
    total_pages = Column(Integer)
    # Free-form status string; the set of allowed values is not defined here.
    status = Column(String(50))

class Image(Base):
    """ORM model for the ``images`` table: one row per image saved from a document.

    Rows are created by ``DocumentProcessor.analyze_images``, which stores the
    vision model's description, type and extracted text next to the file path.
    """
    __tablename__ = 'images'
    
    id = Column(Integer, primary_key=True)
    # Plain integer column; no ForeignKey constraint to `documents` is declared.
    document_id = Column(Integer, nullable=False)
    # 1-based page number (the extractor stores page_num + 1).
    page_number = Column(Integer, nullable=False)
    image_path = Column(String(500), nullable=False)
    description = Column(Text)
    image_type = Column(String(50))
    ocr_text = Column(Text)

class Chunk(Base):
    """ORM model for the ``chunks`` table: one row per text chunk of a document."""
    __tablename__ = 'chunks'
    
    id = Column(Integer, primary_key=True)
    # Plain integer column; no ForeignKey constraint to `documents` is declared.
    document_id = Column(Integer, nullable=False)
    chunk_index = Column(Integer, nullable=False)
    text = Column(Text, nullable=False)
    heading_path = Column(String(500))
    token_count = Column(Integer)
    has_images = Column(Boolean, default=False)
    # NOTE(review): stored as Text - presumably a serialized (JSON?) blob;
    # confirm against the code that writes it.
    chunk_metadata = Column(Text)

In [5]:
# ============================================================================
# OLLAMA CLIENT WITH STREAMING SUPPORT
# ============================================================================

class OllamaClient:
    """Client for interacting with Ollama API with streaming support"""
    
    def __init__(self, config: RAGConfig):
        self.config = config
        self.base_url = config.ollama_url
    
    def generate(
        self, 
        model: str, 
        prompt: str, 
        system: str = "",
        images: List[str] = None,
        timeout: int = 300,
        stream: bool = False
    ) -> str:
        """Generate text with Ollama (with optional streaming)"""
        url = f"{self.base_url}/api/generate"
        
        payload = {
            "model": model,
            "prompt": prompt,
            "stream": stream
        }
        
        if system:
            payload["system"] = system
        
        if images:
            payload["images"] = images
        
        try:
            if stream:
                # Streaming mode - print tokens as they arrive
                response = requests.post(url, json=payload, timeout=timeout, stream=True)
                response.raise_for_status()
                
                full_response = ""
                for line in response.iter_lines():
                    if line:
                        chunk = json.loads(line)
                        if "response" in chunk:
                            token = chunk["response"]
                            print(token, end='', flush=True)
                            full_response += token
                        
                        # Check if done
                        if chunk.get("done", False):
                            break
                
                print()  # Newline after streaming
                return full_response
            else:
                # Non-streaming mode - wait for complete response
                response = requests.post(url, json=payload, timeout=timeout)
                response.raise_for_status()
                return response.json()["response"]
                
        except requests.exceptions.Timeout:
            print(f"\n❌ Ollama timeout after {timeout}s - model may be too slow or stuck")
            return ""
        except Exception as e:
            print(f"\n❌ Ollama error: {e}")
            if hasattr(e, 'response') and e.response is not None:
                print(f"Response content: {e.response.text}")
            return ""
    
    def analyze_image(self, image_path: str) -> Dict[str, str]:
        """Analyze image using Gemma3 multimodal model with enhanced OCR extraction"""
        if not os.path.exists(image_path):
            print(f"❌ Image not found: {image_path}")
            return {
                'description': 'Image not found',
                'type': 'error',
                'ocr_text': ''
            }
            
        # Read image and convert to base64
        try:
            with open(image_path, "rb") as f:
                import base64
                image_data = base64.b64encode(f.read()).decode('utf-8')
                
            # Enhanced prompt for better OCR and context extraction
            description_prompt = """Analyze this image carefully and provide detailed information:

1. TYPE: Classify this image (diagram, flowchart, chart, graph, table, screenshot, architecture diagram, code snippet, formula, etc.)

2. DESCRIPTION: Describe what the image shows in 2-3 detailed sentences. Include:
   - Main subject/purpose
   - Key components or elements
   - Relationships between elements (if applicable)
   - Colors, arrows, or visual indicators (if relevant)

3. TEXT: Extract ALL visible text from the image. This is CRITICAL for search accuracy.
   - Include labels, titles, legends, annotations
   - Include numbers, percentages, values
   - Include code, formulas, equations
   - Include any text in tables, boxes, or speech bubbles
   - Preserve the order and structure where possible
   - If no text is visible, write "No text visible"

Format your response EXACTLY as follows:
TYPE: [type]
DESCRIPTION: [description]
TEXT: [all extracted text]"""
            
            response = self.generate(
                model="gemma3:4b",  # Using Gemma3 for image analysis
                prompt=description_prompt,
                images=[image_data],
                timeout=120,  # Increased timeout for image analysis
                stream=False  # Don't stream for image analysis
            )
            
            # Parse response with more robust parsing
            result = {
                'description': 'No description generated',
                'type': 'unknown',
                'ocr_text': ''
            }
            
            if response:
                # Try to extract sections using regex
                import re
                
                # Try to find TYPE
                type_match = re.search(r'TYPE:\s*(.+)', response, re.IGNORECASE)
                if type_match:
                    result['type'] = type_match.group(1).strip().lower()
                
                # Try to find DESCRIPTION
                desc_match = re.search(r'DESCRIPTION:([\s\S]*?)(?=TEXT:|$)', response, re.IGNORECASE)
                if desc_match:
                    result['description'] = desc_match.group(1).strip()
                
                # Try to find TEXT
                text_match = re.search(r'TEXT:([\s\S]*)', response, re.IGNORECASE)
                if text_match:
                    ocr = text_match.group(1).strip()
                    # Don't store if it's just the "no text" message
                    if ocr.lower() != "no text visible":
                        result['ocr_text'] = ocr
            
            return result
            
        except Exception as e:
            print(f"❌ Error analyzing image {image_path}: {str(e)}")
            return {
                'description': f'Error analyzing image: {str(e)}',
                'type': 'error',
                'ocr_text': ''
            }
    
    def chat(
        self, 
        messages: List[Dict[str, str]], 
        context: str = None,
        stream: bool = True  # Enable streaming by default!
    ) -> str:
        """Chat with context - with EXTREMELY strong anti-hallucination instructions and streaming"""
        
        # Build system message with EXTREMELY STRONG anti-hallucination instructions
        if context:
            system_msg = """You are a document question-answering assistant. Follow these rules with ABSOLUTE strictness:

!! CRITICAL RULES - NO EXCEPTIONS !!

1. You MUST ONLY use information explicitly stated in the context below
2. DO NOT use any knowledge outside the provided context
3. DO NOT make inferences, assumptions, or educated guesses
4. DO NOT mention products, services, or technologies not explicitly in the context
5. If information is NOT in the context, respond EXACTLY: "I don't have that information in the provided documents"
6. DO NOT provide links, URLs, or suggest where to find more information
7. DO NOT say things like "for the latest information" or "check the official website"
8. When answering, cite the specific source number (e.g., "According to Source 2...")

CONTEXT FROM DOCUMENTS:
""" + context + """

Remember: If it's not in the context above, you DON'T KNOW IT. Period."""
        else:
            system_msg = "You are a helpful AI assistant. Please provide accurate and helpful responses based only on what you know."
        
        # Build prompt from messages
        prompt = "\n".join([
            f"{msg['role']}: {msg['content']}" 
            for msg in messages
        ])
        
        return self.generate(
            model=self.config.chat_model,
            prompt=prompt,
            system=system_msg,
            timeout=self.config.ollama_timeout,
            stream=stream  # Pass streaming flag
        )

In [6]:
# ============================================================================
# MARKDOWN-AWARE SEMANTIC CHUNKER
# ============================================================================

class MarkdownSemanticChunker:
    """Intelligent markdown chunking that respects document structure"""
    
    def __init__(self, config: RAGConfig):
        """Keep the config and load the tokenizer used for token counting.

        NOTE(review): counts come from the ``bert-base-uncased`` tokenizer,
        while ``RAGConfig.embedding_model`` defaults to a different model -
        confirm the two tokenizations are close enough for the chunk limits.
        """
        tokenizer_name = "bert-base-uncased"
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    
    def chunk_markdown(self, markdown_text: str, doc_context: str = "") -> List[Dict]:
        """Create semantically meaningful chunks"""
        sections = self._parse_markdown_hierarchy(markdown_text)
        chunks = self._create_chunks_from_sections(sections, doc_context)
        optimized_chunks = self._optimize_chunks(chunks)
        return optimized_chunks
    
    def _parse_markdown_hierarchy(self, text: str) -> List[Dict]:
        """Parse markdown into hierarchical sections"""
        lines = text.split('\n')
        sections = []
        current_section = None
        heading_stack = []
        
        for line in lines:
            heading_match = re.match(r'^(#{1,6})\s+(.+)$', line)
            
            if heading_match:
                if current_section:
                    sections.append(current_section)
                
                level = len(heading_match.group(1))
                title = heading_match.group(2).strip()
                
                heading_stack = [(lvl, ttl) for lvl, ttl in heading_stack if lvl < level]
                heading_stack.append((level, title))
                
                parent_path = ' > '.join([ttl for _, ttl in heading_stack[:-1]])
                full_path = ' > '.join([ttl for _, ttl in heading_stack])
                
                current_section = {
                    'level': level,
                    'title': title,
                    'content': '',
                    'parent_path': parent_path,
                    'full_path': full_path
                }
            else:
                if current_section is not None:
                    current_section['content'] += line + '\n'
                else:
                    if not sections or sections[-1]['level'] != 0:
                        sections.append({
                            'level': 0,
                            'title': 'Introduction',
                            'content': line + '\n',
                            'parent_path': '',
                            'full_path': 'Introduction'
                        })
                    else:
                        sections[-1]['content'] += line + '\n'
        
        if current_section:
            sections.append(current_section)
        
        return sections
    
    def _create_chunks_from_sections(self, sections: List[Dict], doc_context: str) -> List[Dict]:
        """Create chunks from sections"""
        chunks = []
        current_chunk = None
        
        for section in sections:
            section_text = self._format_section_text(section)
            section_tokens = self._count_tokens(section_text)
            
            if section_tokens > self.config.max_chunk_size:
                if current_chunk:
                    chunks.append(current_chunk)
                    current_chunk = None
                
                split_chunks = self._split_large_section(section, doc_context)
                chunks.extend(split_chunks)
            
            elif section_tokens >= self.config.min_chunk_size:
                if current_chunk:
                    chunks.append(current_chunk)
                    current_chunk = None
                
                chunks.append({
                    'text': section_text,
                    'heading_path': section['full_path'],
                    'level': section['level'],
                    'token_count': section_tokens,
                    'doc_context': doc_context,
                    'type': 'section'
                })
            
            else:
                if current_chunk is None:
                    current_chunk = {
                        'text': section_text,
                        'heading_path': section['parent_path'] or section['title'],
                        'level': section['level'],
                        'token_count': section_tokens,
                        'doc_context': doc_context,
                        'type': 'accumulated',
                        'sections': [section['title']]
                    }
                else:
                    combined_text = current_chunk['text'] + '\n\n' + section_text
                    combined_tokens = self._count_tokens(combined_text)
                    
                    if combined_tokens <= self.config.max_chunk_size:
                        current_chunk['text'] = combined_text
                        current_chunk['token_count'] = combined_tokens
                        current_chunk['sections'].append(section['title'])
                    else:
                        chunks.append(current_chunk)
                        current_chunk = {
                            'text': section_text,
                            'heading_path': section['parent_path'] or section['title'],
                            'level': section['level'],
                            'token_count': section_tokens,
                            'doc_context': doc_context,
                            'type': 'accumulated',
                            'sections': [section['title']]
                        }
        
        if current_chunk:
            chunks.append(current_chunk)
        
        return chunks
    
    def _split_large_section(self, section: Dict, doc_context: str) -> List[Dict]:
        """Split large section at paragraph boundaries"""
        heading_text = f"# {section['title']}\n\n"
        parent_context = f"Context: {section['parent_path']}\n\n" if section['parent_path'] else ""
        
        paragraphs = re.split(r'\n\n+', section['content'].strip())
        
        chunks = []
        current_text = heading_text + parent_context
        current_tokens = self._count_tokens(current_text)
        
        for para in paragraphs:
            para_tokens = self._count_tokens(para)
            
            if current_tokens + para_tokens <= self.config.max_chunk_size:
                current_text += para + '\n\n'
                current_tokens += para_tokens
            else:
                if current_text.strip() != heading_text.strip():
                    chunks.append({
                        'text': current_text.strip(),
                        'heading_path': section['full_path'],
                        'level': section['level'],
                        'token_count': current_tokens,
                        'doc_context': doc_context,
                        'type': 'split_section',
                        'part': len(chunks) + 1
                    })
                
                current_text = heading_text + parent_context + para + '\n\n'
                current_tokens = self._count_tokens(current_text)
        
        if current_text.strip():
            chunks.append({
                'text': current_text.strip(),
                'heading_path': section['full_path'],
                'level': section['level'],
                'token_count': current_tokens,
                'doc_context': doc_context,
                'type': 'split_section',
                'part': len(chunks) + 1
            })
        
        return chunks
    
    def _optimize_chunks(self, chunks: List[Dict]) -> List[Dict]:
        """Merge very small chunks"""
        optimized = []
        i = 0
        
        while i < len(chunks):
            chunk = chunks[i]
            
            if (chunk['token_count'] < self.config.min_chunk_size and 
                i < len(chunks) - 1):
                
                next_chunk = chunks[i + 1]
                combined_text = chunk['text'] + '\n\n' + next_chunk['text']
                combined_tokens = self._count_tokens(combined_text)
                
                if combined_tokens <= self.config.max_chunk_size:
                    merged_chunk = {
                        'text': combined_text,
                        'heading_path': chunk['heading_path'],
                        'token_count': combined_tokens,
                        'doc_context': chunk['doc_context'],
                        'type': 'merged'
                    }
                    optimized.append(merged_chunk)
                    i += 2
                    continue
            
            optimized.append(chunk)
            i += 1
        
        return optimized
    
    def _format_section_text(self, section: Dict) -> str:
        """Format section with heading and context"""
        parts = []
        
        if section['parent_path']:
            parts.append(f"[Context: {section['parent_path']}]")
        
        if section['title'] and section['title'] != 'Introduction':
            heading_prefix = '#' * section['level']
            parts.append(f"{heading_prefix} {section['title']}")
        
        parts.append(section['content'].strip())
        
        return '\n\n'.join(parts)
    
    def _count_tokens(self, text: str) -> int:
        """Count tokens in text with truncation"""
        return len(self.tokenizer.encode(
            text, 
            add_special_tokens=False,
            truncation=True,
            max_length=512
        ))

In [7]:
# ============================================================================
# DOCUMENT PROCESSOR WITH IMAGE EXTRACTION
# ============================================================================

class DocumentProcessor:
    """Handles PDF processing with image extraction and analysis"""
    
    def __init__(self, config: RAGConfig, ollama_client: OllamaClient):
        """Wire up the collaborators and make sure the image folder exists.

        Args:
            config: Settings object; ``images_dir`` is created when missing.
            ollama_client: Client stored as ``self.ollama`` for image analysis.
        """
        self.config, self.ollama = config, ollama_client
        self.chunker = MarkdownSemanticChunker(config)

        # Extracted images are written here, so the folder has to be present
        os.makedirs(self.config.images_dir, exist_ok=True)
    
    def _sanitize_utf8(self, text: str) -> str:
        """Sanitize text to remove invalid UTF-8 characters"""
        if not text:
            return text
        
        # Encode to UTF-8 with error handling, then decode back
        # This removes any invalid UTF-8 sequences
        try:
            # First try strict encoding
            return text.encode('utf-8', errors='ignore').decode('utf-8', errors='ignore')
        except Exception as e:
            print(f"    ⚠️  UTF-8 sanitization error: {e}")
            # Fallback: replace problematic characters
            return text.encode('ascii', errors='ignore').decode('ascii', errors='ignore')
    
    def pdf_to_markdown(self, pdf_path: str) -> str:
        """Convert PDF to Markdown using PyMuPDF4LLM"""
        markdown_text = pymupdf4llm.to_markdown(pdf_path)
        # Sanitize to remove invalid UTF-8
        return self._sanitize_utf8(markdown_text)
    
    def _group_nearby_rectangles(self, rects: List[fitz.Rect], proximity_threshold: float = 20) -> List[List[int]]:
        """Group rectangles that are close to each other"""
        if not rects:
            return []

        # Each rect gets assigned to a group
        groups = []
        assigned = [False] * len(rects)

        for i, rect in enumerate(rects):
            if assigned[i]:
                continue

            # Start a new group
            current_group = [i]
            assigned[i] = True

            # Find all rects that should be in this group
            changed = True
            while changed:
                changed = False
                for j, other_rect in enumerate(rects):
                    if assigned[j]:
                        continue

                    # Check if this rect is close to any rect in current group
                    for group_idx in current_group:
                        group_rect = rects[group_idx]

                        # Calculate distance between rectangles
                        # Expand each rect by proximity_threshold and check for intersection
                        expanded_group = fitz.Rect(
                            group_rect.x0 - proximity_threshold,
                            group_rect.y0 - proximity_threshold,
                            group_rect.x1 + proximity_threshold,
                            group_rect.y1 + proximity_threshold
                        )

                        if expanded_group.intersects(other_rect):
                            current_group.append(j)
                            assigned[j] = True
                            changed = True
                            break

            groups.append(current_group)

        return groups

    def extract_images_from_pdf(
        self,
        pdf_path: str,
        document_id: int,
        min_image_size: int = 50,  # Minimum width/height in pixels
        proximity_threshold: float = 20  # Group images within this distance (points)
    ) -> List[Dict]:
        """
        Extract images from PDF with intelligent grouping.
        Groups nearby images together to capture complete diagrams.
        """
        doc = fitz.open(pdf_path)
        images = []

        for page_num in range(len(doc)):
            page = doc[page_num]
            image_list = page.get_images(full=True)

            if not image_list:
                continue

            # Get bounding boxes for all images on this page
            image_bboxes = []
            for img_info in image_list:
                xref = img_info[0]
                # Get all instances of this image on the page
                rects = page.get_image_rects(xref)
                if rects:
                    for rect in rects:
                        # Check minimum size
                        width = rect.width
                        height = rect.height
                        if width >= min_image_size and height >= min_image_size:
                            image_bboxes.append({
                                'rect': rect,
                                'xref': xref,
                                'width': width,
                                'height': height
                            })

            if not image_bboxes:
                continue

            # Group nearby images
            rects_only = [bbox['rect'] for bbox in image_bboxes]
            groups = self._group_nearby_rectangles(rects_only, proximity_threshold)

            # Process each group
            for group_idx, group in enumerate(groups):
                if len(group) == 1:
                    # Single image - extract normally
                    bbox = image_bboxes[group[0]]
                    try:
                        base_image = doc.extract_image(bbox['xref'])
                        image_bytes = base_image["image"]
                        pil_image = PILImage.open(io.BytesIO(image_bytes))

                        # Save image
                        image_filename = f"doc{document_id}_page{page_num+1}_img{len(images)+1}.png"
                        image_path = os.path.join(self.config.images_dir, image_filename)

                        if pil_image.mode == 'RGBA':
                            pil_image = pil_image.convert('RGB')

                        pil_image.save(image_path, 'PNG')

                        images.append({
                            'page_number': page_num + 1,
                            'image_path': image_path,
                            'image_index': len(images),
                            'is_composite': False,
                            'bbox': bbox['rect']
                        })
                    except Exception as e:
                        print(f"    ⚠️  Failed to extract single image on page {page_num+1}: {e}")

                else:
                    # Multiple images grouped together - capture as screenshot
                    # Calculate bounding box that encompasses all images in group
                    union_rect = image_bboxes[group[0]]['rect']
                    for idx in group[1:]:
                        union_rect = union_rect | image_bboxes[idx]['rect']  # Union of rectangles

                    # Add some padding
                    padding = 5
                    union_rect = fitz.Rect(
                        max(0, union_rect.x0 - padding),
                        max(0, union_rect.y0 - padding),
                        min(page.rect.width, union_rect.x1 + padding),
                        min(page.rect.height, union_rect.y1 + padding)
                    )

                    try:
                        # Render this region as an image
                        mat = fitz.Matrix(2, 2)  # 2x zoom for better quality
                        pix = page.get_pixmap(matrix=mat, clip=union_rect)

                        # Convert to PIL Image
                        img_data = pix.tobytes("png")
                        pil_image = PILImage.open(io.BytesIO(img_data))

                        # Save composite image
                        image_filename = f"doc{document_id}_page{page_num+1}_composite{group_idx+1}.png"
                        image_path = os.path.join(self.config.images_dir, image_filename)

                        pil_image.save(image_path, 'PNG')

                        images.append({
                            'page_number': page_num + 1,
                            'image_path': image_path,
                            'image_index': len(images),
                            'is_composite': True,
                            'num_components': len(group),
                            'bbox': union_rect
                        })

                        print(f"    📊 Grouped {len(group)} images into composite on page {page_num+1}")

                    except Exception as e:
                        print(f"    ⚠️  Failed to create composite image on page {page_num+1}: {e}")

        doc.close()
        return images
    
    def analyze_images(
        self, 
        images: List[Dict],
        document_id: int,
        db_session
    ) -> List[int]:
        """Describe each extracted image with the vision model and store the result.

        Every entry of ``images`` is passed to ``self.ollama.analyze_image``; the
        returned description, type and OCR text are sanitized and written as an
        ``Image`` row linked to ``document_id``. The session is flushed after each
        row and committed once after the loop.

        Args:
            images: Dicts carrying at least ``'page_number'`` and ``'image_path'``.
            document_id: ID of the owning document row.
            db_session: Database session used for add / flush / commit.

        Returns:
            IDs of the created image rows, in the same order as ``images``.
        """
        saved_ids = []

        for position, entry in enumerate(images, start=1):
            print(f"    Analyzing image {position} on page {entry['page_number']}...", end=' ')
            began = time.time()

            vision_result = self.ollama.analyze_image(entry['image_path'])

            # Sanitize every model-produced string before it reaches the database
            record = Image(
                document_id=document_id,
                page_number=entry['page_number'],
                image_path=entry['image_path'],
                description=self._sanitize_utf8(vision_result['description']),
                image_type=self._sanitize_utf8(vision_result['type']),
                ocr_text=self._sanitize_utf8(vision_result['ocr_text'])
            )
            db_session.add(record)
            db_session.flush()  # flush before reading record.id
            saved_ids.append(record.id)

            print(f"✓ ({time.time() - began:.1f}s)")

        db_session.commit()
        return saved_ids
    
    def enrich_chunks_with_images(
        self,
        chunks: List[Dict],
        images_data: List[Dict],
        db_session
    ) -> List[Dict]:
        """Append image descriptions and OCR text to chunks that mention visuals.

        Relevance is a coarse keyword heuristic: a chunk whose text contains any
        of the visual keywords below receives the context of *every* image in
        ``images_data``; other chunks receive none. (Neither the chunk dicts nor
        the image dicts handled here carry a position that would allow a finer
        match.)

        Args:
            chunks: Chunk dicts; each must have a ``'text'`` entry. The input
                dicts are shallow-copied, not modified.
            images_data: Image dicts with ``'image_path'`` and, optionally,
                ``'type'``, ``'description'`` and ``'ocr_text'`` (any of which
                may be ``None``).
            db_session: Unused; kept so existing callers keep working.

        Returns:
            New chunk dicts with sanitized ``'text'`` and a ``'has_images'``
            flag. Enriched chunks additionally carry ``'image_paths'`` and
            ``'image_metadata'``.
        """
        visual_keywords = (
            'figure', 'image', 'diagram', 'chart', 'screenshot', 'see below', 'shown in'
        )

        enriched_chunks = []

        for chunk in chunks:
            chunk_copy = chunk.copy()

            # The keyword test depends only on the chunk, so evaluate it once
            # per chunk instead of once per (image, keyword) pair.
            lowered_text = chunk['text'].lower()
            mentions_visuals = any(keyword in lowered_text for keyword in visual_keywords)
            relevant_images = list(images_data) if mentions_visuals else []

            if relevant_images:
                # Build comprehensive image context including OCR text
                context_lines = ["\n\n[Images in this section]:\n"]
                image_metadata = []

                for img in relevant_images:
                    # Missing / None fields must not crash or print "None"
                    img_type = img.get('type') or 'image'
                    description = img.get('description') or ''
                    ocr_text = img.get('ocr_text') or ''

                    context_lines.append(f"- {img_type.capitalize()}: {description}\n")

                    # CRITICAL: Add OCR text if available (makes text in images searchable!)
                    if ocr_text.strip():
                        context_lines.append(f"  Text visible in image: {ocr_text}\n")

                    # Metadata keeps the raw values as received
                    image_metadata.append({
                        'path': img['image_path'],
                        'description': img.get('description'),
                        'type': img.get('type'),
                        'ocr_text': img.get('ocr_text', '')
                    })

                chunk_copy['text'] = self._sanitize_utf8(chunk['text'] + "".join(context_lines))
                chunk_copy['has_images'] = True
                chunk_copy['image_paths'] = [img['image_path'] for img in relevant_images]
                chunk_copy['image_metadata'] = image_metadata
            else:
                chunk_copy['text'] = self._sanitize_utf8(chunk['text'])
                chunk_copy['has_images'] = False

            enriched_chunks.append(chunk_copy)

        return enriched_chunks
    
    def process_document(
        self, 
        pdf_path: str,
        db_session
    ) -> Tuple[List[Dict], int]:
        """Run the full ingestion pipeline for one PDF and persist the results.

        Steps, in order:
          1. Convert the PDF to Markdown (``self.pdf_to_markdown``).
          2. Create a ``Document`` row (status ``'processing'``), extract the
             PDF's images and analyze them with the vision model.
          3. Split the Markdown into chunks via ``self.chunker.chunk_markdown``.
          4. Append image descriptions / OCR text to chunks when images exist.
          5. Store one ``Chunk`` row per chunk and mark the document ``'indexed'``.

        Progress and per-step timings are printed to stdout.

        Args:
            pdf_path: Path of the PDF file to ingest.
            db_session: Database session used for all inserts, queries and commits.

        Returns:
            ``(chunks, document_id)`` - the chunk dicts (after enrichment and
            sanitization) and the ID of the created ``Document`` row.

        Note:
            There is no error handling here: if a step raises after the
            ``Document`` row has been committed, the exception propagates and
            the row keeps its ``'processing'`` status.
        """
        print(f"\n{'='*60}")
        print(f"Processing: {pdf_path}")
        print(f"{'='*60}")
        
        # Step 1: Convert to markdown
        print("\n[Step 1/5] Converting PDF to Markdown...", end=' ')
        start_time = time.time()
        markdown_text = self.pdf_to_markdown(pdf_path)
        elapsed = time.time() - start_time
        print(f"✓ {elapsed:.2f}s")
        print(f"  • Extracted {len(markdown_text):,} characters")
        
        # Create document record
        # Committed immediately so doc.id can be used for images and chunks below
        doc = Document(
            filename=os.path.basename(pdf_path),
            status='processing'
        )
        db_session.add(doc)
        db_session.commit()
        
        # Step 2: Extract and analyze images
        print("\n[Step 2/5] Extracting and analyzing images...")
        start_time = time.time()
        
        images = self.extract_images_from_pdf(pdf_path, doc.id)
        
        if images:
            image_ids = self.analyze_images(images, doc.id, db_session)
            
            # Get image data for enrichment
            # Read the stored rows back so enrichment uses the persisted
            # (sanitized) description / type / OCR text
            images_data = []
            for img_id in image_ids:
                img_record = db_session.query(Image).filter_by(id=img_id).first()
                if img_record:
                    images_data.append({
                        'image_path': img_record.image_path,
                        'description': img_record.description,
                        'type': img_record.image_type,
                        'ocr_text': img_record.ocr_text
                    })
        else:
            images_data = []
        
        elapsed = time.time() - start_time
        print(f"  ✓ Completed in {elapsed:.2f}s")
        print(f"  • Extracted {len(images)} images")
        if images:
            print(f"  • Vision analysis: ✓")
        
        # Step 3: Markdown-aware semantic chunking
        print("\n[Step 3/5] Markdown-aware semantic chunking...", end=' ')
        start_time = time.time()
        # Document-level context handed to the chunker: file name plus the
        # first 500 characters of the Markdown
        doc_context = f"Document: {os.path.basename(pdf_path)}\n\n{markdown_text[:500]}"
        chunks = self.chunker.chunk_markdown(markdown_text, doc_context)
        elapsed = time.time() - start_time
        print(f"✓ {elapsed:.2f}s")
        print(f"  • Created {len(chunks)} semantic chunks")
        
        # Step 4: Enrich chunks with image context (INCLUDING OCR TEXT!)
        print("\n[Step 4/5] Enriching chunks with image context...", end=' ')
        start_time = time.time()
        if images_data:
            chunks = self.enrich_chunks_with_images(chunks, images_data, db_session)
            chunks_with_images = sum(1 for c in chunks if c.get('has_images', False))
            elapsed = time.time() - start_time
            print(f"✓ {elapsed:.2f}s")
            print(f"  • {chunks_with_images} chunks enriched with image context + OCR text")
        else:
            # Still sanitize even if no images
            # (this branch rewrites chunk['text'] on the chunk dicts in place)
            for chunk in chunks:
                chunk['text'] = self._sanitize_utf8(chunk['text'])
            elapsed = time.time() - start_time
            print(f"✓ {elapsed:.2f}s")
            print(f"  • No images to enrich")
        
        # Step 5: Save to database
        print("\n[Step 5/5] Saving chunks to database...", end=' ')
        start_time = time.time()
        for idx, chunk in enumerate(chunks):
            # Keys without a dedicated column are serialized into the JSON
            # chunk_metadata field
            chunk_record = Chunk(
                document_id=doc.id,
                chunk_index=idx,
                text=self._sanitize_utf8(chunk['text']),  # Sanitize before saving
                heading_path=chunk.get('heading_path', ''),
                token_count=chunk.get('token_count', 0),
                has_images=chunk.get('has_images', False),
                chunk_metadata=json.dumps({
                    k: v for k, v in chunk.items() 
                    if k not in ['text', 'heading_path', 'token_count', 'has_images']
                })
            )
            db_session.add(chunk_record)
        
        # Chunk rows and the status change are committed together
        doc.status = 'indexed'
        db_session.commit()
        elapsed = time.time() - start_time
        print(f"✓ {elapsed:.2f}s")
        
        return chunks, doc.id

In [8]:
# ============================================================================
# JINA COLBERT V2 RETRIEVER (NO RAGATOUILLE!)
# ============================================================================

class JinaColBERTRetriever:
    """Direct implementation of Jina ColBERT v2 (no RAGatouille dependency)

    The model is loaded through ``SentenceTransformer`` and documents are
    ranked by cosine similarity between the query embedding and each document
    embedding. Token-level (3-D) embeddings, if they ever reach the scorer, are
    mean-pooled first, so no per-token late interaction is performed here.
    """

    def __init__(self, config: RAGConfig):
        """Load the embedding model named by ``config.embedding_model``.

        Args:
            config: Supplies ``embedding_model``, ``device`` and
                ``colbert_index_path``.
        """
        self.config = config
        self.model = SentenceTransformer(
            config.embedding_model,
            trust_remote_code=True,
            device=config.device
        )
        # Set max sequence length to avoid truncation warnings
        self.model.max_seq_length = 512
        self.corpus_embeddings = None
        self.corpus = None

    def index(self, corpus: List[str]) -> None:
        """Encode ``corpus``, keep it in memory and save it to disk.

        The embeddings and the corpus texts are written to ``index.pt`` inside
        ``config.colbert_index_path`` (created if missing).

        Args:
            corpus: Document texts; position in this list is the document ID
                reported by ``search``.
        """
        self.corpus = corpus

        print(f"  Encoding {len(corpus)} documents...")

        # NOTE(review): the shape of encode()'s output depends on the loaded
        # model; the scorer below accepts 1-D, 2-D and 3-D tensors.
        self.corpus_embeddings = self.model.encode(
            corpus,
            show_progress_bar=True,
            convert_to_tensor=True,
            batch_size=8  # Smaller batch size for stability
        )

        # Save to disk
        os.makedirs(self.config.colbert_index_path, exist_ok=True)
        torch.save({
            'embeddings': self.corpus_embeddings,
            'corpus': corpus
        }, os.path.join(self.config.colbert_index_path, 'index.pt'))

    def load(self) -> None:
        """Load embeddings and corpus texts from ``index.pt`` written by ``index``.

        Raises:
            FileNotFoundError: If no index file exists at the configured path.
        """
        index_file = os.path.join(self.config.colbert_index_path, 'index.pt')
        # The file only contains a tensor and a list of strings, so restrict
        # unpickling to plain data instead of allowing arbitrary objects.
        data = torch.load(
            index_file,
            map_location=self.config.device,
            weights_only=True
        )
        self.corpus_embeddings = data['embeddings']
        self.corpus = data['corpus']

    def search(self, query: str, k: int = 10) -> List[Dict]:
        """Return the ``k`` best-scoring corpus documents for ``query``.

        Args:
            query: Query text.
            k: Maximum number of hits; capped at the corpus size.

        Returns:
            Dicts with ``'document_id'`` (corpus position), ``'score'`` and
            ``'text'``, best first. Empty when nothing is indexed or ``k <= 0``.
        """
        if not self.corpus or k <= 0:
            return []

        query_embedding = self.model.encode(query, convert_to_tensor=True)

        # Flatten so a single-document corpus is handled like any other
        scores = self._maxsim_score(query_embedding, self.corpus_embeddings).reshape(-1)

        top = torch.topk(scores, k=min(k, scores.numel()))

        return [
            {'document_id': doc_idx, 'score': score, 'text': self.corpus[doc_idx]}
            for doc_idx, score in zip(top.indices.tolist(), top.values.tolist())
        ]

    def rerank(self, query: str, documents: List[str], k: int = 10) -> List[Dict]:
        """Score ``documents`` against ``query`` and return the best ``k``.

        Args:
            query: Query text.
            documents: Candidate texts to re-score.
            k: Maximum number of results.

        Returns:
            Dicts with ``'result_index'`` (position in ``documents``),
            ``'score'``, 1-based ``'rank'`` and ``'text'``, best first. Empty
            when ``documents`` is empty or ``k <= 0``.
        """
        if not documents or k <= 0:
            return []

        query_embedding = self.model.encode(query, convert_to_tensor=True)
        doc_embeddings = self.model.encode(
            documents,
            convert_to_tensor=True,
            batch_size=8  # Smaller batch size for stability
        )

        scores = self._maxsim_score(query_embedding, doc_embeddings).reshape(-1)

        order = torch.argsort(scores, descending=True)[:k].tolist()
        score_values = scores.tolist()

        return [
            {
                'result_index': doc_idx,
                'score': score_values[doc_idx],
                'rank': rank,
                'text': documents[doc_idx]
            }
            for rank, doc_idx in enumerate(order, 1)
        ]

    @staticmethod
    def _as_matrix(embeddings: torch.Tensor) -> torch.Tensor:
        """Reduce an embedding tensor to 2-D: mean over dim 1 for 3-D input, add a row dim for 1-D."""
        if embeddings.dim() == 3:
            return embeddings.mean(dim=1)
        if embeddings.dim() == 1:
            return embeddings.unsqueeze(0)
        return embeddings

    def _maxsim_score(
        self,
        query_embedding: torch.Tensor,
        doc_embeddings: torch.Tensor
    ) -> torch.Tensor:
        """
        Cosine similarity between a query embedding and document embeddings.

        Despite the name, no per-token maximum is taken: 3-D inputs are
        mean-pooled over dim 1, 1-D inputs are treated as a single row, and the
        L2-normalized matrices are multiplied.

        Returns:
            One score per document as a 1-D tensor when there is a single
            query row; for several query rows the squeezed score matrix.
        """
        query_vec = torch.nn.functional.normalize(self._as_matrix(query_embedding), p=2, dim=-1)
        doc_vec = torch.nn.functional.normalize(self._as_matrix(doc_embeddings), p=2, dim=-1)

        scores = torch.mm(query_vec, doc_vec.t())

        # Return as 1D tensor
        return scores.squeeze(0) if scores.size(0) == 1 else scores.squeeze()

In [9]:
# ============================================================================
# DUAL INDEXER (BM25s + Jina ColBERT)
# ============================================================================

class DualIndexer:
    """Builds, saves and loads the two search indexes (BM25s and Jina ColBERT v2)."""

    def __init__(self, config: RAGConfig):
        """Keep the config and create the ColBERT retriever; BM25 starts unset."""
        self.config = config
        self.bm25_retriever = None
        self.colbert_retriever = JinaColBERTRetriever(config)

    def build_bm25_index(self, corpus: List[str]) -> None:
        """Tokenize ``corpus``, index it with BM25s and save it to ``config.bm25_index_path``.

        Tokenization uses English stop-word removal and an English stemmer.
        """
        print("\n[BM25s] Building lexical search index...", end=' ')
        began = time.time()

        english_stemmer = Stemmer.Stemmer("english")
        tokenized = bm25s.tokenize(corpus, stopwords="en", stemmer=english_stemmer)

        self.bm25_retriever = bm25s.BM25()
        self.bm25_retriever.index(tokenized)

        # Persist next to the other indexes so load_indexes() can restore it
        target_dir = self.config.bm25_index_path
        os.makedirs(target_dir, exist_ok=True)
        self.bm25_retriever.save(target_dir)

        print(f"✓ {time.time() - began:.2f}s")

    def build_colbert_index(self, corpus: List[str]) -> None:
        """Delegate indexing of ``corpus`` to the ColBERT retriever and report the time taken."""
        print("\n[ColBERT] Building semantic search index...")
        began = time.time()

        self.colbert_retriever.index(corpus)

        print(f"  ✓ {time.time() - began:.2f}s")

    def load_indexes(self) -> None:
        """Restore the BM25 index from ``config.bm25_index_path`` and the ColBERT index from its own path."""
        self.bm25_retriever = bm25s.BM25.load(self.config.bm25_index_path)
        self.colbert_retriever.load()

In [10]:
# ============================================================================
# HYBRID RETRIEVER WITH RRF AND RERANKING
# ============================================================================

class HybridRetriever:
    """Three-stage retrieval: BM25s + ColBERT + ColBERT Reranking

    BM25 and dense hits are fused with reciprocal rank fusion, the fused
    candidates are loaded from the database, and the dense retriever re-scores
    them against the query to produce the final ordering.
    """

    # Upper bound on fused candidates that are fetched and reranked
    MAX_RERANK_CANDIDATES = 50

    def __init__(self, config: RAGConfig, indexer: DualIndexer, db_session, corpus_to_chunk_id: Optional[List[int]] = None):
        """Store collaborators and the corpus-position -> chunk-ID mapping.

        Args:
            config: Supplies ``bm25_top_k``, ``colbert_top_k`` and ``final_top_k``.
            indexer: Provides ``bm25_retriever`` and ``colbert_retriever``.
            db_session: Database session used to load chunk rows.
            corpus_to_chunk_id: ``corpus_to_chunk_id[i]`` is the database ID of
                the chunk at corpus position ``i``.
        """
        self.config = config
        self.indexer = indexer
        self.db_session = db_session
        self.stemmer = Stemmer.Stemmer("english")
        # CRITICAL: Mapping from corpus index to database chunk ID
        self.corpus_to_chunk_id = corpus_to_chunk_id or []

    def retrieve(self, query: str, top_k_final: Optional[int] = None) -> List[Dict]:
        """Three-stage hybrid retrieval with detailed scoring

        Args:
            query: User query text.
            top_k_final: Number of chunks to return; defaults to
                ``config.final_top_k``.

        Returns:
            Reranked chunk dicts (see ``_colbert_rerank``), best first. Each
            carries the final score plus the BM25, ColBERT and RRF scores of
            that same chunk (0.0 where a stage did not return it).
        """
        if top_k_final is None:
            top_k_final = self.config.final_top_k

        print(f"\n🔍 Retrieving relevant chunks...")

        # Get corpus size to adjust k values
        corpus = self.indexer.colbert_retriever.corpus
        corpus_size = len(corpus) if corpus else 0

        # Adjust k values based on corpus size
        bm25_k = min(self.config.bm25_top_k, corpus_size) if corpus_size > 0 else self.config.bm25_top_k
        colbert_k = min(self.config.colbert_top_k, corpus_size) if corpus_size > 0 else self.config.colbert_top_k

        print(f"   • Corpus size: {corpus_size}, using k={bm25_k} for retrieval")

        # Stage 1: BM25s
        start = time.time()
        bm25_results = self._bm25_search(query, k=bm25_k)
        bm25_time = time.time() - start
        print(f"   • BM25s: {bm25_time:.3f}s ({len(bm25_results)} results)")

        # Stage 2: ColBERT
        start = time.time()
        colbert_results = self._colbert_search(query, k=colbert_k)
        colbert_time = time.time() - start
        print(f"   • ColBERT: {colbert_time:.3f}s ({len(colbert_results)} results)")

        # Fusion
        start = time.time()
        fused_results = self._reciprocal_rank_fusion(bm25_results, colbert_results)
        candidates = fused_results[:self.MAX_RERANK_CANDIDATES]
        fusion_time = time.time() - start
        print(f"   • Fusion: {fusion_time:.3f}s ({len(candidates)} candidates)")

        # Fetch chunks - USING THE MAPPING!
        start = time.time()
        candidate_corpus_indices = [r['corpus_index'] for r in candidates]
        candidate_chunks = self._fetch_chunks_from_db(candidate_corpus_indices)

        # PRESERVE INTERMEDIATE SCORES
        # Map corpus_index to intermediate scores
        score_map: Dict[int, Dict[str, float]] = {}
        for hit in bm25_results:
            score_map.setdefault(hit['corpus_index'], {})['bm25_score'] = hit['score']
        for hit in colbert_results:
            score_map.setdefault(hit['corpus_index'], {})['colbert_score'] = hit['score']
        for fused in candidates:
            score_map.setdefault(fused['corpus_index'], {})['rrf_score'] = fused['rrf_score']

        # Attach scores by the corpus index stored on each chunk. Pairing by
        # list position would be wrong: the fetch step drops candidates it
        # cannot load, so positions no longer line up with the candidate list.
        for chunk in candidate_chunks:
            chunk['intermediate_scores'] = score_map.get(chunk.get('corpus_index'), {})

        fetch_time = time.time() - start
        print(f"   • Fetch: {fetch_time:.3f}s ({len(candidate_chunks)} chunks)")

        # Stage 3: Rerank
        start = time.time()
        final_k = min(top_k_final, len(candidate_chunks))
        reranked_results = self._colbert_rerank(query, candidate_chunks, top_k=final_k)
        rerank_time = time.time() - start
        print(f"   • Rerank: {rerank_time:.3f}s (top {len(reranked_results)})")

        total_time = bm25_time + colbert_time + fusion_time + fetch_time + rerank_time
        print(f"   ✓ Total retrieval: {total_time:.3f}s")

        return reranked_results

    def _bm25_search(self, query: str, k: int) -> List[Dict]:
        """Stage 1: BM25s lexical search

        The query is tokenized with English stop words and ``self.stemmer``.
        Returns dicts with ``'corpus_index'``, ``'score'`` and ``'source'``.
        """
        query_tokens = bm25s.tokenize(
            query,
            stopwords="en",
            stemmer=self.stemmer
        )

        results, scores = self.indexer.bm25_retriever.retrieve(query_tokens, k=k)

        # Row 0 holds the hits for the single query
        return [
            {'corpus_index': int(results[0][i]), 'score': float(scores[0][i]), 'source': 'bm25'}
            for i in range(len(results[0]))
        ]

    def _colbert_search(self, query: str, k: int) -> List[Dict]:
        """Stage 2: ColBERT semantic search

        Returns dicts with ``'corpus_index'``, ``'score'`` and ``'source'``.
        """
        results = self.indexer.colbert_retriever.search(query=query, k=k)
        return [
            {'corpus_index': r['document_id'], 'score': r['score'], 'source': 'colbert'}
            for r in results
        ]

    def _reciprocal_rank_fusion(
        self,
        bm25_results: List[Dict],
        colbert_results: List[Dict],
        k: int = 60
    ) -> List[Dict]:
        """RRF fusion

        Each list contributes ``1 / (k + rank)`` (rank starting at 1) to the
        score of every corpus index it contains.

        Returns:
            Dicts with ``'corpus_index'`` and ``'rrf_score'``, highest first.
        """
        scores = {}

        for ranked_list in (bm25_results, colbert_results):
            for rank, result in enumerate(ranked_list, 1):
                corpus_idx = result['corpus_index']
                scores[corpus_idx] = scores.get(corpus_idx, 0) + (1 / (k + rank))

        sorted_results = sorted(scores.items(), key=lambda x: x[1], reverse=True)
        return [{'corpus_index': idx, 'rrf_score': score} for idx, score in sorted_results]

    def _fetch_chunks_from_db(self, corpus_indices: List[int]) -> List[Dict]:
        """Fetch chunks from database using corpus index -> chunk ID mapping

        Indices outside the mapping and chunk IDs missing from the database are
        reported on stdout and skipped, so the result may be shorter than
        ``corpus_indices``. Every returned dict records the ``'corpus_index'``
        it was loaded for.
        """
        chunks = []

        for corpus_idx in corpus_indices:
            # Convert corpus index to database chunk ID
            # (a negative index must not wrap around to the end of the mapping)
            if 0 <= corpus_idx < len(self.corpus_to_chunk_id):
                chunk_id = self.corpus_to_chunk_id[corpus_idx]

                # Fetch from database using the actual chunk ID
                chunk = self.db_session.query(Chunk).filter_by(id=chunk_id).first()
                if chunk:
                    chunks.append({
                        'chunk_id': chunk.id,
                        'corpus_index': corpus_idx,
                        'text': chunk.text,
                        'document_id': chunk.document_id,
                        'heading_path': chunk.heading_path,
                        'has_images': chunk.has_images,
                        'metadata': json.loads(chunk.chunk_metadata) if chunk.chunk_metadata else {}
                    })
                else:
                    print(f"  ⚠️ Chunk ID {chunk_id} not found in database")
            else:
                print(f"  ⚠️ Corpus index {corpus_idx} out of range (max: {len(self.corpus_to_chunk_id)-1})")

        return chunks

    def _colbert_rerank(self, query: str, chunks: List[Dict], top_k: int) -> List[Dict]:
        """Stage 3: ColBERT reranking with score preservation

        Args:
            query: User query text.
            chunks: Candidate chunk dicts as produced by ``_fetch_chunks_from_db``,
                optionally carrying ``'intermediate_scores'``.
            top_k: Number of results to keep.

        Returns:
            Chunk dicts with the rerank ``'score'`` and ``'rank'`` plus
            ``'bm25_score'``, ``'colbert_score'`` and ``'rrf_score'``.
        """
        if not chunks:
            return []

        documents = [chunk['text'] for chunk in chunks]
        reranked_results = self.indexer.colbert_retriever.rerank(query=query, documents=documents, k=top_k)

        final_results = []
        for result in reranked_results:
            original_chunk = chunks[result['result_index']]
            intermediate_scores = original_chunk.get('intermediate_scores', {})

            final_results.append({
                'chunk_id': original_chunk['chunk_id'],
                'text': original_chunk['text'],
                'document_id': original_chunk['document_id'],
                'heading_path': original_chunk.get('heading_path', ''),
                'has_images': original_chunk.get('has_images', False),
                'metadata': original_chunk['metadata'],
                'score': result['score'],  # Final ColBERT rerank score (cosine similarity)
                'rank': result['rank'],
                'bm25_score': intermediate_scores.get('bm25_score', 0.0),
                'colbert_score': intermediate_scores.get('colbert_score', 0.0),
                'rrf_score': intermediate_scores.get('rrf_score', 0.0)
            })
        return final_results

In [11]:
# ============================================================================
# RAG CHATBOT WITH STREAMING
# ============================================================================

class RAGChatbot:
    """Complete RAG chatbot with Ollama and streaming support

    Keeps the running conversation in ``conversation_history`` as a list of
    ``{'role', 'content'}`` dicts and answers each query with context built
    from the retriever's results.
    """

    def __init__(self, config: RAGConfig, retriever: HybridRetriever, ollama_client: OllamaClient):
        """Store collaborators and start with an empty conversation."""
        self.config = config
        self.retriever = retriever
        self.ollama = ollama_client
        self.conversation_history = []

    def chat(self, query: str, stream: bool = True) -> Dict:
        """Process user query and generate response with streaming

        Args:
            query: The user's message.
            stream: Forwarded to ``self.ollama.chat``; also selects which
                progress messages are printed.

        Returns:
            Dict with ``'response'`` (the model output), ``'sources'``
            (see ``_format_sources``) and ``'retrieved_chunks'`` (their count).

        Raises:
            Whatever ``self.retriever.retrieve`` or ``self.ollama.chat`` raise.
            When generation fails, the user turn added for this call is removed
            again so the history does not end with an unanswered message.
        """
        # Retrieve relevant chunks
        retrieved_chunks = self.retriever.retrieve(query)

        # Build context
        context = self._build_context(retrieved_chunks)

        # Generate response with streaming
        if stream:
            print(f"\n🤖 Generating response (streaming)...\n")
        else:
            print(f"\n🤖 Generating response...", end=' ')

        start_time = time.time()

        self.conversation_history.append({
            'role': 'user',
            'content': query
        })

        try:
            response = self.ollama.chat(
                messages=self.conversation_history,
                context=context,
                stream=stream
            )
        except Exception:
            # Undo the user turn appended above; otherwise the next call would
            # send two consecutive user messages.
            self.conversation_history.pop()
            raise

        elapsed = time.time() - start_time

        if not stream:
            print(f"✓ {elapsed:.1f}s")
        else:
            print(f"\n⏱️  Response generated in {elapsed:.1f}s")

        self.conversation_history.append({
            'role': 'assistant',
            'content': response
        })

        return {
            'response': response,
            'sources': self._format_sources(retrieved_chunks),
            'retrieved_chunks': len(retrieved_chunks)
        }

    def _build_context(self, chunks: List[Dict]) -> str:
        """Build context from retrieved chunks

        Each chunk becomes a ``[Source N ...]`` header line (with heading path
        and image count when available) followed by the chunk text.
        """
        context_parts = []

        for i, chunk in enumerate(chunks, 1):
            heading = f" ({chunk['heading_path']})" if chunk.get('heading_path') else ""

            # Add image info if present
            image_info = ""
            if chunk.get('has_images') and chunk.get('metadata', {}).get('image_paths'):
                num_images = len(chunk['metadata']['image_paths'])
                image_info = f" [Contains {num_images} image(s)]"

            context_parts.append(f"[Source {i}{heading}{image_info}]\n{chunk['text']}\n")

        return "\n".join(context_parts)

    def _format_sources(self, chunks: List[Dict]) -> List[Dict]:
        """Format source citations with full text, image paths, and ALL scores

        ``'preview'`` is the first 200 characters followed by ``"..."`` for
        longer texts, otherwise the full text.
        """
        sources = []

        for i, chunk in enumerate(chunks):
            text = chunk['text']
            source = {
                'source_id': i + 1,
                'chunk_id': chunk['chunk_id'],
                'document_id': chunk['document_id'],
                'heading': chunk.get('heading_path', ''),
                'score': chunk['score'],  # Final ColBERT rerank score
                'bm25_score': chunk.get('bm25_score', 0.0),
                'colbert_score': chunk.get('colbert_score', 0.0),
                'rrf_score': chunk.get('rrf_score', 0.0),
                'has_images': chunk.get('has_images', False),
                'text': text,  # Include full text
                'preview': text[:200] + "..." if len(text) > 200 else text
            }

            # Add image paths if available
            if chunk.get('has_images') and chunk.get('metadata'):
                source['image_paths'] = chunk['metadata'].get('image_paths', [])

            sources.append(source)

        return sources

    def clear_history(self):
        """Clear conversation history"""
        self.conversation_history = []
        print("🗑️  Conversation history cleared")

In [12]:
class RAGApplication:
    """Main application orchestrator.

    Owns the database session, the Ollama client, the document processor and
    the indexer, and creates the retriever and chatbot on demand via
    ``initialize_chatbot()``.
    """

    # Words ignored when comparing a query with image metadata: common
    # question/function words plus generic image vocabulary.
    _IMAGE_STOP_WORDS = frozenset({
        'what', 'is', 'are', 'the', 'a', 'an', 'how', 'why', 'when', 'where',
        'can', 'could', 'would', 'should', 'do', 'does', 'did', 'of', 'in', 'on',
        'for', 'to', 'with', 'by', 'from', 'at', 'about', 'as', 'into', 'through',
        'diagram', 'chart', 'figure', 'image', 'screenshot', 'show', 'me', 'please',
    })

    def __init__(self, config: RAGConfig):
        """Create the database session and the long-lived components.

        Args:
            config: Settings object; this class reads ``db_path``,
                ``ollama_url`` and ``base_dir`` from it and passes it on to
                the components it constructs.
        """
        self.config = config

        # Database setup
        db_url = f"sqlite:///{config.db_path}"
        self.engine = create_engine(db_url)
        Base.metadata.create_all(self.engine)
        Session = sessionmaker(bind=self.engine)
        self.db_session = Session()

        # Initialize Ollama client
        self.ollama = OllamaClient(config)

        # Initialize components
        self.processor = DocumentProcessor(config, self.ollama)
        self.indexer = DualIndexer(config)
        self.retriever = None  # created by initialize_chatbot()
        self.chatbot = None    # created by initialize_chatbot()

        # CRITICAL: Store mapping between corpus index and chunk IDs
        self.corpus_to_chunk_id = []  # Maps corpus index -> database chunk ID

    def _corpus_mapping_path(self) -> str:
        """Return the path of the pickled corpus-index -> chunk-ID mapping."""
        return os.path.join(self.config.base_dir, "indexes", "corpus_mapping.pkl")

    def check_ollama(self) -> bool:
        """Return True if the Ollama ``/api/tags`` endpoint answers with HTTP 200.

        Any ``requests`` failure (connection refused, timeout, ...) is
        reported as False. Other exceptions, including KeyboardInterrupt,
        are deliberately not swallowed.
        """
        try:
            response = requests.get(f"{self.config.ollama_url}/api/tags", timeout=5)
        except requests.RequestException:
            return False
        return response.status_code == 200

    def index_documents(self, pdf_paths: List[str]) -> None:
        """Process the given PDFs, then rebuild both indexes from the database.

        The corpus is rebuilt from every ``Chunk`` row currently in the
        database (ordered by id), not only from the documents passed in, and
        the corpus-index -> chunk-ID mapping is saved next to the indexes.

        Args:
            pdf_paths: Paths of the PDF files to process.
        """
        if not self.check_ollama():
            print("❌ Ollama is not running!")
            print("Please start Ollama: ollama serve")
            return

        for pdf_path in pdf_paths:
            # The return value is not needed here: the corpus is rebuilt from
            # the database below (presumably process_document stores its
            # chunks through the session -- confirm in DocumentProcessor).
            self.processor.process_document(pdf_path, self.db_session)

        print(f"\n{'='*60}")
        print("Building Indexes")
        print(f"{'='*60}")

        # Build corpus and mapping.
        # CRITICAL: position i of `corpus` must correspond to position i of
        # `corpus_to_chunk_id`, so both are derived from the same ordered rows.
        all_db_chunks = self.db_session.query(Chunk).order_by(Chunk.id).all()
        corpus = [chunk.text for chunk in all_db_chunks]
        self.corpus_to_chunk_id = [chunk.id for chunk in all_db_chunks]

        print(f"  • Corpus: {len(corpus)} chunks")
        print(f"  • Chunk ID mapping: {len(self.corpus_to_chunk_id)} entries")

        # Build indexes
        self.indexer.build_bm25_index(corpus)
        self.indexer.build_colbert_index(corpus)

        # Save the mapping to disk for later use
        import pickle
        mapping_path = self._corpus_mapping_path()
        os.makedirs(os.path.dirname(mapping_path), exist_ok=True)
        with open(mapping_path, 'wb') as f:
            pickle.dump(self.corpus_to_chunk_id, f)

        print(f"\n✅ Document indexed successfully!")

    def initialize_chatbot(self) -> None:
        """Load the saved indexes and mapping, then create retriever and chatbot.

        Returns without creating anything when Ollama is unreachable. A
        missing mapping file produces a warning and an empty mapping.
        """
        if not self.check_ollama():
            print("❌ Ollama is not running!")
            print("Please start Ollama: ollama serve")
            return

        print("Loading indexes...")
        self.indexer.load_indexes()

        # Load the corpus-to-chunk-id mapping.
        # NOTE(security): pickle.load can execute arbitrary code. This file is
        # written by index_documents(); never point it at untrusted data.
        import pickle
        mapping_path = self._corpus_mapping_path()
        try:
            with open(mapping_path, 'rb') as f:
                self.corpus_to_chunk_id = pickle.load(f)
            print(f"  • Loaded {len(self.corpus_to_chunk_id)} chunk ID mappings")
        except FileNotFoundError:
            print("  ⚠️  Warning: No corpus mapping found. Please re-index your documents.")
            self.corpus_to_chunk_id = []

        self.retriever = HybridRetriever(self.config, self.indexer, self.db_session, self.corpus_to_chunk_id)
        self.chatbot = RAGChatbot(self.config, self.retriever, self.ollama)

        print("✅ Chatbot initialized and ready!")

    def chat(self, query: str) -> Dict:
        """Forward ``query`` to the chatbot and return its result dict.

        Raises:
            RuntimeError: If ``initialize_chatbot()`` has not created a chatbot.
        """
        if not self.chatbot:
            raise RuntimeError("Chatbot not initialized. Call initialize_chatbot() first.")

        return self.chatbot.chat(query)

    @staticmethod
    def _keywords(text: str, stop_words) -> set:
        """Return the set of lower-cased keywords found in ``text``.

        Tokens are split on whitespace and stripped of leading/trailing
        punctuation (so "architecture?" and "architecture," both become
        "architecture"), then kept only if they are longer than two
        characters and not in ``stop_words``.
        """
        keywords = set()
        for raw in text.lower().split():
            token = re.sub(r"^\W+|\W+$", "", raw)
            if len(token) > 2 and token not in stop_words:
                keywords.add(token)
        return keywords

    def _filter_relevant_images(self, query: str, image_paths: List[str], chunk_text: str,
                                min_overlap: int = 3) -> List[str]:
        """Keep only the images whose stored metadata shares keywords with the query.

        For each path the ``Image`` row whose ``image_path`` ends with the
        file name is looked up, and the keywords of its description, image
        type and OCR text are intersected with the query keywords.

        Args:
            query: The user's question.
            image_paths: Candidate image paths.
            chunk_text: Accepted for interface compatibility; not used.
            min_overlap: Number of shared keywords required. It is capped at
                the number of distinct query keywords, because a query with
                fewer keywords than the threshold could otherwise never match
                any image; in that case every query keyword must be present.

        Returns:
            The subset of ``image_paths`` that passed the filter, in order.
        """
        if not image_paths:
            return []

        query_keywords = self._keywords(query, self._IMAGE_STOP_WORDS)
        if not query_keywords:
            return []  # No meaningful query words, don't show images

        required = max(1, min(min_overlap, len(query_keywords)))

        relevant_images = []
        for img_path in image_paths:
            # Only the file name is used for the lookup.
            img_filename = os.path.basename(img_path)

            # autoescape=True makes "_" and "%" inside the file name match
            # literally instead of acting as LIKE wildcards.
            img_record = self.db_session.query(Image).filter(
                Image.image_path.endswith(img_filename, autoescape=True)
            ).first()
            if not img_record:
                continue

            # Combine all textual image metadata into one searchable string.
            image_text = " ".join((
                img_record.description or "",
                img_record.image_type or "",
                img_record.ocr_text or "",
            ))
            image_keywords = self._keywords(image_text, self._IMAGE_STOP_WORDS)

            if len(query_keywords & image_keywords) >= required:
                relevant_images.append(img_path)

        return relevant_images

    def _display_chunk_with_images(self, chunk_text: str, image_paths: Optional[List[str]] = None) -> None:
        """Print ``chunk_text`` (if non-empty) and render each image inline.

        Missing files and display failures are reported per image and do not
        stop the remaining images from being shown.
        """
        from IPython.display import display, Image as IPImage

        # Display chunk text
        if chunk_text:
            print(f"{chunk_text}\n")

        # Display images if available
        if image_paths:
            print(f"  📷 Relevant Images ({len(image_paths)}):")
            for img_path in image_paths:
                if os.path.exists(img_path):
                    try:
                        display(IPImage(filename=img_path, width=400))
                        print(f"  └─ {os.path.basename(img_path)}\n")
                    except Exception as e:
                        print(f"  └─ ⚠️ Could not display {os.path.basename(img_path)}: {e}\n")
                else:
                    print(f"  └─ ⚠️ Image not found: {os.path.basename(img_path)}\n")

    def _print_source(self, idx: int, src: Dict, query: str) -> None:
        """Print one retrieved chunk: scores, heading, text preview and matching images."""
        print(f"┌─ Chunk {idx} {'─'*50}")

        # Show ALL retrieval scores
        print(f"│ 🎯 Final Score (ColBERT Rerank): {src['score']:.4f}")
        print("│ 📈 Intermediate Scores:")
        print(f"│    • BM25 (lexical):      {src.get('bm25_score', 0.0):.4f}")
        print(f"│    • ColBERT (semantic):  {src.get('colbert_score', 0.0):.4f}")
        print(f"│    • RRF (fusion):        {src.get('rrf_score', 0.0):.4f}")

        if src['heading']:
            print(f"│ 📍 Section: {src['heading']}")

        if src['has_images']:
            print("│ 🖼️  Contains Images: Yes")

        print("│")
        print("│ 📄 Text:")

        # Show at most the first 300 characters of the chunk.
        chunk_text = src.get('text', src.get('preview', ''))
        truncated = len(chunk_text) > 300
        shown = chunk_text[:300] + "..." if truncated else chunk_text

        # Prefix every line so multi-line chunks stay inside the box drawing.
        for line in shown.splitlines() or [""]:
            print(f"│ {line}")
        if truncated:
            print(f"│ [Truncated - {len(chunk_text)} total characters]")

        # Filter and display only the images that match the query.
        if src['has_images'] and src.get('image_paths'):
            relevant_images = self._filter_relevant_images(
                query,
                src['image_paths'],
                chunk_text
            )

            print("│")
            if relevant_images:
                print(f"│ [Showing {len(relevant_images)}/{len(src['image_paths'])} images matching your query]")
                self._display_chunk_with_images("", relevant_images)
            else:
                print("│ [This chunk has images, but none directly match your specific query]")

        print(f"└{'─'*60}\n")

    def interactive_chat(self) -> None:
        """Run a read-answer loop on stdin until the user exits.

        Commands: ``exit`` / ``quit`` leave the loop, ``clear`` clears the
        conversation history. Ctrl-C or end of input also leave the loop.
        Other errors are printed with a traceback and the loop continues.
        Returns immediately if the chatbot has not been initialized.
        """
        if not self.chatbot:
            print("❌ Chatbot not initialized. Call initialize_chatbot() first.")
            return

        print("\n" + "="*60)
        print("RAG Chatbot - Interactive Mode")
        print("="*60)
        print("Type your questions (or 'exit' to quit, 'clear' to clear history)\n")

        while True:
            try:
                user_input = input("You: ").strip()

                if not user_input:
                    continue

                command = user_input.lower()

                if command in ('exit', 'quit'):
                    print("\nGoodbye! 👋")
                    break

                if command == 'clear':
                    self.chatbot.clear_history()
                    continue

                result = self.chat(user_input)
                print(f"\nAssistant: {result['response']}\n")

                # Show retrieved chunks with ALL SCORES
                if result['sources']:
                    print(f"\n{'='*60}")
                    print(f"📊 Retrieved Chunks with Similarity Scores ({len(result['sources'])})")
                    print(f"{'='*60}\n")

                    for idx, src in enumerate(result['sources'], 1):
                        self._print_source(idx, src, user_input)

                    print()

            except (KeyboardInterrupt, EOFError):
                # EOFError: input() hit end of input. Without this it would be
                # caught below and the loop would spin forever.
                print("\n\nGoodbye! 👋")
                break
            except Exception as e:
                print(f"\n❌ Error: {e}\n")
                import traceback
                traceback.print_exc()

    def print_stats(self) -> None:
        """Print the number of documents, chunks and images in the database."""
        doc_count = self.db_session.query(Document).count()
        chunk_count = self.db_session.query(Chunk).count()
        image_count = self.db_session.query(Image).count()

        print(f"\n📊 Database Statistics:")
        print(f"   • Documents: {doc_count}")
        print(f"   • Chunks: {chunk_count}")
        print(f"   • Images: {image_count}")

In [None]:
# Initialize config. Only the chat model is overridden here (llama3.2:3b);
# the vision model and every other setting keep their RAGConfig defaults.
# Note: gpt-oss:20b is available but VERY slow. Use only if you need maximum quality.
config = RAGConfig(chat_model='llama3.2:3b')
app = RAGApplication(config)

# Check Ollama
if not app.check_ollama():
    print("❌ Ollama is not running!")
    print("\nTo start Ollama:")
    print("  1. Open a terminal")
    print("  2. Run: ollama serve")
    print("  3. Keep that terminal open")
    print("\nThen run this cell again.")
else:
    # Simple menu with proper exit handling
    exit_program = False

    while not exit_program:
        print("\n" + "="*50)
        print("RAG Chatbot - Choose an option:")
        print("1. Upload and index a PDF")
        print("2. Start interactive chat")
        print("3. Show database statistics")
        print("4. Exit")

        choice = input("\nEnter your choice (1-4): ").strip()

        if choice == '1':
            # Accept quoted paths and "~", and require a regular file: a
            # directory would pass os.path.exists.
            file_path = input("Enter the path to your PDF file: ").strip().strip('"\'')
            file_path = os.path.expanduser(file_path)
            if os.path.isfile(file_path):
                app.index_documents([file_path])
            else:
                print(f"Error: File not found at {file_path}")

        elif choice == '2':
            app.initialize_chatbot()
            # initialize_chatbot() returns without creating a chatbot when
            # Ollama is unreachable; do not enter the chat loop in that case.
            if app.chatbot is not None:
                app.interactive_chat()
                # Back to main menu after chat exits
                print("\n[Returned to main menu]")
            else:
                print("\n⚠️  Chatbot could not be initialized; returning to main menu.")

        elif choice == '3':
            app.print_stats()

        elif choice == '4':
            print("\n" + "="*50)
            print("Goodbye! 👋")
            print("="*50)
            exit_program = True

        else:
            print("Invalid choice. Please enter a number between 1-4.")

    print("\n✅ Program exited successfully.")

No sentence-transformers model found with name jinaai/jina-colbert-v2. Creating a new one with mean pooling.
`torch_dtype` is deprecated! Use `dtype` instead!
`torch_dtype` is deprecated! Use `dtype` instead!



RAG Chatbot - Choose an option:
1. Upload and index a PDF
2. Start interactive chat
3. Show database statistics
4. Exit
Loading indexes...
  • Loaded 3 chunk ID mappings
✅ Chatbot initialized and ready!

RAG Chatbot - Interactive Mode
Type your questions (or 'exit' to quit, 'clear' to clear history)


🔍 Retrieving relevant chunks...
   • Corpus size: 3, using k=3 for retrieval


Split strings:   0%|          | 0/1 [00:00<?, ?it/s]

Stem Tokens:   0%|          | 0/1 [00:00<?, ?it/s]

BM25S Retrieve:   0%|          | 0/1 [00:00<?, ?it/s]

   • BM25s: 0.089s (3 results)
   • ColBERT: 0.452s (3 results)
   • Fusion: 0.000s (3 candidates)
   • Fetch: 0.004s (3 chunks)
   • Rerank: 0.303s (top 3)
   ✓ Total retrieval: 0.848s

🤖 Generating response (streaming)...

According to Source 1 (**What is Claude Skills** > How Claude Skills Can Accelerate Our AI Development Work), Claude Skills can accelerate your AI development work by:

* Creating **reusable Skills** for common development scenarios, such as:
	+ "AI-Systems-Standards" Skill (Python coding conventions, preferred libraries and frameworks, testing and evaluation patterns, deployment standards)
	+ "KAI-Project" Skill (KAI's architecture and components, retrieval strategies, integration requirements with existing systems, performance benchmarks and evaluation metrics)
	+ "GenAI-Best-Practices" Skill (approach to prompt engineering, RAG implementation patterns, model evaluation frameworks, security and data privacy guidelines)
* Reducing **time overhead** by eliminating 

Split strings:   0%|          | 0/1 [00:00<?, ?it/s]

Stem Tokens:   0%|          | 0/1 [00:00<?, ?it/s]

BM25S Retrieve:   0%|          | 0/1 [00:00<?, ?it/s]

   • BM25s: 0.013s (3 results)
   • ColBERT: 0.392s (3 results)
   • Fusion: 0.000s (3 candidates)
   • Fetch: 0.001s (3 chunks)
   • Rerank: 0.290s (top 3)
   ✓ Total retrieval: 0.696s

🤖 Generating response (streaming)...

According to Source 1 (**What is Claude Skills** > Key Technical Architecture), the key technical architecture of Claude Skills involves:

* **Progressive Disclosure**: At startup, Claude only loads the name and description of each skill (just metadata), then dynamically loads full details only when relevant to the task. This keeps things token-efficient while maintaining access to deep expertise.

Source: Source 1

⏱️  Response generated in 4.0s

Assistant: According to Source 1 (**What is Claude Skills** > Key Technical Architecture), the key technical architecture of Claude Skills involves:

* **Progressive Disclosure**: At startup, Claude only loads the name and description of each skill (just metadata), then dynamically loads full details only when relevant to

Split strings:   0%|          | 0/1 [00:00<?, ?it/s]

Stem Tokens:   0%|          | 0/1 [00:00<?, ?it/s]

BM25S Retrieve:   0%|          | 0/1 [00:00<?, ?it/s]

   • BM25s: 0.012s (3 results)
   • ColBERT: 0.089s (3 results)
   • Fusion: 0.000s (3 candidates)
   • Fetch: 0.001s (3 chunks)
   • Rerank: 0.301s (top 3)
   ✓ Total retrieval: 0.402s

🤖 Generating response (streaming)...

I don't have that information in the provided documents.

⏱️  Response generated in 3.1s

Assistant: I don't have that information in the provided documents.


📊 Retrieved Chunks with Similarity Scores (3)

┌─ Chunk 1 ──────────────────────────────────────────────────
│ 🎯 Final Score (ColBERT Rerank): 0.2319
│ 📈 Intermediate Scores:
│    • BM25 (lexical):      0.0000
│    • ColBERT (semantic):  0.2319
│    • RRF (fusion):        0.0323
│ 📍 Section: **What is Claude Skills** > How Claude Skills Can Accelerate Our AI Development Work
│
│ 📄 Text:
│ [Context: **What is Claude Skills**]

## How Claude Skills Can Accelerate Our AI Development Work

**The Current Situation:**


When our AI engineering team uses Claude for development help, we typically spend time explaini