# PDF RAG Implementation with Advanced OCR

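This notebook demonstrates a Retrieval-Augmented Generation (RAG) pipeline for PDF documents: page text is extracted with PyPDF2, split into overlapping chunks, embedded through Ollama, indexed with FAISS, and passed as context to Mistral to answer questions. It also provides OCR helper utilities (result caching, gibberish filtering, and OCR text cleanup) and a parallel-processing demo aimed at improving processing speed.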

In [None]:
# Cell 1: Imports and Setup
import os
import faiss
import numpy as np
import PyPDF2
import ollama
import cv2
import pytesseract
import nltk
from nltk.corpus import words as nltk_words
from nltk.tokenize import word_tokenize
from IPython.display import display, Markdown
import ipywidgets as widgets
import re
from collections import Counter
import time
import hashlib
import pickle
from concurrent.futures import ThreadPoolExecutor

# Tesseract location: point pytesseract at the executable here if required.
# pytesseract.pytesseract.tesseract_cmd = r"C:\path\to\tesseract.exe"  # Uncomment and set your path

# Download the NLTK resources used by this notebook
nltk.download('punkt')
nltk.download('words')

# Lower-cased English vocabulary plus a short list of very common words;
# both sets are consulted by the gibberish check.
ENGLISH_WORDS = {entry.lower() for entry in nltk_words.words()}
COMMON_WORDS = {'the', 'and', 'to', 'of', 'a', 'in', 'that', 'is', 'was', 'for'}

# OCR results are cached in an "ocr_cache" folder under the current working
# directory; the folder is created on first run.
CACHE_DIR = os.path.join(os.getcwd(), "ocr_cache")
os.makedirs(CACHE_DIR, exist_ok=True)


In [None]:
%pip install ipywidgets faiss-cpu PyPDF2 pytesseract opencv-python nltk


Note: you may need to restart the kernel to use updated packages.


In [None]:
# Cell 2: Text Extraction from PDF with OCR Fallback
def extract_text_from_pdfs(uploaded_files, min_text_chars=100):
    """Extract the embedded text layer of every page of the given PDFs.

    No OCR is performed here: pages with little or no extractable text only
    trigger a printed notice that they might need OCR.

    Args:
        uploaded_files: Iterable of PDF sources, each passed straight to
            ``PyPDF2.PdfReader`` (the notebook passes file paths).
        min_text_chars: Pages whose stripped text is not longer than this many
            characters are reported as possibly needing OCR.

    Returns:
        str: The concatenated page texts. Every non-empty page is prefixed
        with ``[Page N]`` (1-based, per file) and every page is followed by a
        blank line. An empty input yields ``""``.
    """
    parts = []
    for uploaded_file in uploaded_files:
        reader = PyPDF2.PdfReader(uploaded_file)
        for page_number, page in enumerate(reader.pages, start=1):
            # Treat a missing text layer (None) the same as an empty one.
            page_text = page.extract_text() or ""
            stripped = page_text.strip()

            # Keep whatever text the page has; short pages (titles, captions)
            # are still real content and must not be dropped.
            if stripped:
                parts.append(f"[Page {page_number}] {page_text}\n")

            # Little or no text usually means a scanned/image-only page.
            if len(stripped) <= min_text_chars:
                print(f"Page {page_number} has little text, might need OCR")

            parts.append("\n")
    return "".join(parts)

In [4]:
# Cell 3: Chunking the Text
def chunk_text(text, chunk_size=1000, chunk_overlap=200):
    """Split ``text`` into fixed-size, overlapping character windows.

    Args:
        text: The string to split.
        chunk_size: Maximum number of characters per chunk; must be positive.
        chunk_overlap: Number of characters shared by consecutive chunks;
            must satisfy ``0 <= chunk_overlap < chunk_size``.

    Returns:
        list[str]: Chunks in document order. Empty text yields ``[]``.

    Raises:
        ValueError: If the size/overlap combination is invalid. Without this
            check an overlap >= chunk size would never advance the window.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must satisfy 0 <= chunk_overlap < chunk_size")

    step = chunk_size - chunk_overlap
    total = len(text)
    chunks = []
    start = 0
    while start < total:
        end = min(start + chunk_size, total)
        chunks.append(text[start:end])
        # Stop once the end of the text is covered; another window would only
        # repeat the tail of the chunk that was just emitted.
        if end == total:
            break
        start += step
    return chunks


In [5]:
# Cell 4: Embedding Text using Ollama (mxbai-embed-large)
def get_embedding(text):
    """Embed ``text`` with the Ollama ``mxbai-embed-large`` model.

    The text is wrapped in a retrieval instruction prompt before being sent
    to Ollama.

    Args:
        text: The text to embed.

    Returns:
        numpy.ndarray: The ``"embedding"`` field of the response as a
        ``float32`` array.
    """
    instruction_prompt = f"Represent this sentence for searching relevant passages: {text}"
    reply = ollama.embeddings(model="mxbai-embed-large", prompt=instruction_prompt)
    raw_vector = reply["embedding"]
    return np.array(raw_vector, dtype='float32')

In [6]:
# Cell 5: Build FAISS Index
def build_faiss_index(chunks):
    """Embed every chunk and store the vectors in a flat L2 FAISS index.

    Args:
        chunks: Iterable of text chunks; each one is embedded with
            ``get_embedding``.

    Returns:
        tuple: ``(index, chunks)`` where ``index`` is a ``faiss.IndexFlatL2``
        holding one vector per chunk (row ``i`` belongs to ``chunks[i]``) and
        ``chunks`` is the list of chunks that were indexed.

    Raises:
        ValueError: If ``chunks`` is empty, since the vector dimension cannot
            be determined and there is nothing to search.
    """
    # Materialise once so generators work and the returned list matches the index rows.
    chunks = list(chunks)
    if not chunks:
        raise ValueError("Cannot build a FAISS index from an empty list of chunks")

    # One row per chunk, as float32.
    vectors = np.array([get_embedding(chunk) for chunk in chunks], dtype='float32')
    dim = vectors.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(vectors)
    return index, chunks

In [7]:
# Cell 6: Retrieve Context from Query
def retrieve_context(index, chunks, query, k=1):
    """Return the text of the ``k`` chunks nearest to ``query``.

    Args:
        index: Index object whose ``search(vectors, k)`` returns
            ``(distances, indices)``, e.g. the one built by
            ``build_faiss_index``.
        chunks: The chunk list whose positions match the index rows.
        query: The question text; embedded with ``get_embedding``.
        k: Number of chunks to retrieve (at least 1). Values larger than
            ``len(chunks)`` are clamped.

    Returns:
        str: The retrieved chunks joined by newlines, best match first;
        ``""`` when there are no chunks.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError("k must be at least 1")
    if not chunks:
        return ""

    # Never ask for more neighbours than there are chunks.
    k = min(k, len(chunks))
    query_embedding = get_embedding(query).reshape(1, -1)
    _, indices = index.search(query_embedding, k)

    # FAISS marks missing neighbours with -1; skip those so that a negative
    # index never silently selects the last chunk.
    return "\n".join(chunks[i] for i in indices[0] if 0 <= i < len(chunks))

In [8]:
# Cell 7: Ask Mistral (via Ollama)
def ask_mistral(context, question):
    """Ask the ``mistral`` model (via Ollama) a question about ``context``.

    Args:
        context: Retrieved passage(s) placed at the top of the prompt.
        question: The user's question.

    Returns:
        The ``content`` of the message in the chat response.
    """
    prompt = f"Context: {context}\nQuestion: {question}\nAnswer:"
    user_message = {"role": "user", "content": prompt}
    reply = ollama.chat(model="mistral", messages=[user_message])
    return reply["message"]["content"]

In [None]:
# Cell 8: Improved OCR Performance with Caching
def get_file_hash(file_bytes):
    """Return the hexadecimal MD5 digest of ``file_bytes``, for use as a cache key."""
    hasher = hashlib.md5()
    hasher.update(file_bytes)
    return hasher.hexdigest()

def cache_result(cache_key, result):
    """Persist an OCR result in the cache directory.

    The result is pickled to a temporary file that is then renamed onto
    ``<CACHE_DIR>/<cache_key>.pkl``, so a failed or interrupted write never
    leaves a truncated cache entry behind.

    Args:
        cache_key: Name of the cache entry (the file stem).
        result: Any picklable object.

    Returns:
        bool: ``True`` when the entry was written, ``False`` on any error
        (the error is printed, never raised).
    """
    temp_file = None
    try:
        # Recreate the folder in case it was removed after the setup cell ran.
        os.makedirs(CACHE_DIR, exist_ok=True)
        cache_file = os.path.join(CACHE_DIR, f"{cache_key}.pkl")
        temp_file = f"{cache_file}.tmp"
        with open(temp_file, 'wb') as f:
            pickle.dump(result, f)
        # Swap the finished file into place in a single step.
        os.replace(temp_file, cache_file)
        return True
    except Exception as e:
        print(f"Error caching result: {str(e)}")
        # Best-effort removal of the partial temporary file.
        if temp_file is not None:
            try:
                os.remove(temp_file)
            except OSError:
                pass
        return False

def get_cached_result(cache_key):
    """Load a previously cached OCR result.

    Args:
        cache_key: Name of the cache entry (the file stem under ``CACHE_DIR``).

    Returns:
        The unpickled object, or ``None`` when no entry exists or it could not
        be read (read errors are printed, not raised).
    """
    try:
        pickle_path = os.path.join(CACHE_DIR, f"{cache_key}.pkl")
        if not os.path.exists(pickle_path):
            return None
        # NOTE(review): pickle.load can execute arbitrary code; only use this
        # with cache files written by this notebook.
        with open(pickle_path, 'rb') as handle:
            return pickle.load(handle)
    except Exception as e:
        print(f"Error reading cache: {str(e)}")
        return None

def is_gibberish(text, threshold=0.25):
    """Heuristically decide whether ``text`` is gibberish.

    Only the first 20 lower-cased alphabetic words (two letters or more) are
    inspected. The score is ``0.7 * (share of those words found in neither
    ENGLISH_WORDS nor COMMON_WORDS)``.

    Args:
        text: Candidate text. Falsy values, non-strings, and strings shorter
            than 5 characters after stripping count as gibberish.
        threshold: Score above which the text is reported as gibberish.

    Returns:
        bool: ``True`` when the text looks like gibberish.
    """
    # Guard clauses: nothing usable to analyse.
    if not text:
        return True
    if not isinstance(text, str):
        return True
    if len(text.strip()) < 5:
        return True

    normalized = text.strip().lower()
    tokens = re.findall(r'\b[a-z]{2,}\b', normalized)
    if not tokens:
        return True

    # Inspect a bounded prefix of the words to keep the check cheap.
    inspected = tokens[:20]
    known = 0
    for token in inspected:
        if token in ENGLISH_WORDS or token in COMMON_WORDS:
            known += 1
    word_ratio = known / max(1, len(inspected))

    gibberish_score = (1 - word_ratio) * 0.7
    return gibberish_score > threshold

def clean_ocr_text(text):
    """Strip common OCR artifacts from ``text``.

    Applied in order: all-caps words of five or more letters are blanked,
    runs of non-ASCII characters are blanked, line-break runs are collapsed,
    and finally every whitespace run (line breaks included) becomes a single
    space. The result is stripped of leading and trailing whitespace.

    Args:
        text: OCR output. Falsy values are returned unchanged.

    Returns:
        The cleaned, single-line text.
    """
    if not text:
        return text

    # (pattern, replacement) pairs, applied top to bottom.
    substitutions = (
        (r'\b[A-Z]{5,}\b', ' '),   # all-caps words treated as gibberish
        (r'[^\x00-\x7F]+', ' '),   # non-ASCII characters
        (r'[\r\n]+', '\n'),        # line-break runs
        (r'\s+', ' '),             # any whitespace run, newlines included
    )
    cleaned = text
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()

In [None]:
# Cell 9: Run RAG with Performance Metrics
# List the PDFs to index; each name is resolved against the current working directory.
pdf_files = ["test file.pdf"]  # 📝 Replace with your file names
pdf_paths = [os.path.join(os.getcwd(), pdf_name) for pdf_name in pdf_files]

# Timed pipeline: extract text, split it into chunks, embed them, build the index.
start_time = time.time()
raw_text = extract_text_from_pdfs(pdf_paths)
index, chunks = build_faiss_index(chunk_text(raw_text))
processing_time = time.time() - start_time

# Report how much was indexed and how long it took.
print(
    f"✅ Processed {len(pdf_files)} PDF(s) and created vector store with {len(chunks)} chunks.",
    f"⏱️ Processing time: {processing_time:.2f} seconds",
    sep="\n",
)

✅ Processed 1 PDF(s) and created vector store with 3 chunks.


In [None]:
# Cell 10: Hardcoded Question and Answer with Timing
# `display` and `Markdown` are already imported in the imports cell at the top.

# 🔽 Replace this with your custom question
question = "who did Emma call?"

# Look the pipeline state up through globals() so that running this cell
# before the processing cell prints the hint instead of raising NameError.
if not globals().get("chunks") or globals().get("index") is None:
    print("Please process your PDFs first.")
else:
    # Time retrieval and generation together.
    start_time = time.time()
    context = retrieve_context(index, chunks, question)
    answer = ask_mistral(context, question)
    query_time = time.time() - start_time

    display(Markdown(f"**Question:** {question}"))
    display(Markdown(f"**Answer:** {answer}"))
    print(f"⏱️ Query time: {query_time:.2f} seconds")


**Question:** who did Emma call?

**Answer:**  Emma called her best friend Jake.

In [None]:
# Cell 11: Parallel Processing Demonstration
def process_in_parallel(items, process_func, max_workers=None):
    """Apply ``process_func`` to each item on a thread pool and time the run.

    Args:
        items: Sized collection of inputs.
        process_func: Callable invoked with one item at a time.
        max_workers: Passed to ``ThreadPoolExecutor``.

    Returns:
        list: Results in submission order. An item whose call raised is
        reported with a printed message and left out of the list.
    """
    began = time.time()
    collected = []

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Submit everything first so the calls can overlap.
        pending = []
        for item in items:
            pending.append(pool.submit(process_func, item))

        # Gather in submission order; a failing task is reported and skipped.
        for task in pending:
            try:
                collected.append(task.result())
            except Exception as e:
                print(f"Error in parallel processing: {e}")

    elapsed_time = time.time() - began
    print(f"Processed {len(items)} items in {elapsed_time:.2f} seconds")
    return collected

# Example dummy function to demonstrate parallelization
def dummy_processor(text):
    """Stand-in workload: pause for 0.1 seconds, then return ``len(text)``."""
    simulated_delay = 0.1  # seconds of pretend processing time
    time.sleep(simulated_delay)
    size = len(text)
    return size

# Compare a plain sequential run with the thread-pool helper.
# Each sample is "text" followed by one digit repeated 100 times.
test_items = ["text" + str(digit) * 100 for digit in range(10)]

print("Sequential processing:")
start = time.time()
sequential_results = list(map(dummy_processor, test_items))
print(f"Time taken: {time.time() - start:.2f} seconds")

print("\nParallel processing:")
parallel_results = process_in_parallel(test_items, dummy_processor)
# Both runs should produce the same lengths in the same order.
print(f"Results match: {sequential_results == parallel_results}")

In [None]:
# Cell 12: Demo OCR on an image file with performance measurement
# Uncomment to test with your own image
# The demo is kept disabled inside a string literal; remove the surrounding
# triple quotes to run it (it additionally needs matplotlib).

"""
import cv2
import matplotlib.pyplot as plt

# Load an image file
image_path = "sample_image.jpg"  # Replace with your image path
img = cv2.imread(image_path)
if img is None:
    # cv2.imread signals an unreadable or missing file by returning None
    raise FileNotFoundError(f"Could not read image: {image_path}")

# Preprocess up front: the preview plot below needs `thresh` even when the
# OCR text itself comes from the cache.
# Convert to grayscale
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

# Apply adaptive threshold
thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, 
                            cv2.THRESH_BINARY, 11, 2)

# Calculate image hash for potential caching
img_hash = hashlib.md5(img.tobytes()).hexdigest()
cached_text = get_cached_result(img_hash)

# Compare against None so that a cached empty string still counts as a hit
if cached_text is not None:
    print("Using cached OCR result")
    text = cached_text
else:
    print("Running OCR...")
    start_time = time.time()
    
    # Extract text with Tesseract
    text = pytesseract.image_to_string(thresh)
    cleaned_text = clean_ocr_text(text)
    
    # Cache the result
    cache_result(img_hash, cleaned_text)
    
    ocr_time = time.time() - start_time
    print(f"OCR completed in {ocr_time:.2f} seconds")
    
    text = cleaned_text

# Check if the text is gibberish
gibberish = is_gibberish(text)

# Display image and text
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.title("Original Image")
plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(thresh, cmap='gray')
plt.title("Preprocessed Image")
plt.axis('off')
plt.tight_layout()

print(f"Extracted text: {'(GIBBERISH DETECTED)' if gibberish else ''}")
print(text)
"""