In [None]:
%pip uninstall -y tensorflow tensorflow-gpu tf-keras keras flash-attn
%pip install -q torch torchvision
%pip install -q transformers accelerate
%pip install -q qwen-vl-utils pillow requests
%pip install -q pdf2image pymupdf pillow

In [None]:
"""
## Step 1: Import Dependencies for the Qwen2.5-VL Vision-Language Model

This cell imports all libraries required to run Qwen2.5-VL, a vision-language model
that takes both images and text as input. We load it through AutoModelForVision2Seq
rather than the plain AutoModel because we need the generation-capable
(image + text -> text) model class, e.g. for OCR, image captioning, or visual question answering.
"""
import time
import torch
import numpy as np
from typing import Dict, List, Tuple, Optional
from PIL import Image
import fitz  # PyMuPDF - Used for PDF processing and image extraction
import os
import json
import traceback
from datetime import datetime
from transformers import AutoModelForVision2Seq, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch.nn.functional as F

# Report library versions and hardware availability up front, so that
# compatibility problems surface before the (slow) model load.
import transformers  # already loaded by the imports above; bound here to read its version

print("Libraries loaded successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"Transformers version: {transformers.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

# Show which GPU (if any) will be used: GPU acceleration is critical for
# efficient inference with large vision-language models.
if torch.cuda.is_available():
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
    print(f"Total VRAM: {gpu_props.total_memory / 1024**3:.2f} GB")

In [None]:
"""
## Step 2: Load Qwen2.5 Vision-Language Model

This cell loads the model and processor. Only run this cell once per session
to avoid redundant loading and memory allocation.
"""

checkpoint_path = "Qwen/Qwen2.5-VL-7B-Instruct"

print("Loading model from checkpoint...")

# Load the vision-language model with memory-conscious settings:
# - torch_dtype=torch.bfloat16: keeps weights in BFloat16, roughly halving memory
#   versus FP32. BF16 keeps FP32's exponent range, so it is less overflow-prone
#   than FP16. (Newer transformers releases name this option `dtype`.)
# - device_map="auto": lets accelerate place layers across the available GPU(s)
#   and, if they run out of memory, CPU RAM.
# - trust_remote_code=True: allows executing custom modeling code shipped with
#   the checkpoint repository; only enable it for repositories you trust.
model = AutoModelForVision2Seq.from_pretrained(
    checkpoint_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True
)

# The processor handles both tokenization (chat template -> token IDs) and image
# preprocessing, producing inputs in the structure the model expects.
processor = AutoProcessor.from_pretrained(
    checkpoint_path,
    trust_remote_code=True
)

print("Model loaded successfully!")

# Report VRAM usage to help diagnose out-of-memory problems. Only meaningful when
# a CUDA device exists. Note that this reports device 0 only; with
# device_map="auto" on several GPUs, layers may also live on other devices.
if torch.cuda.is_available():
    print(f"Current VRAM allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")

In [None]:
"""
## Step 3: PDF to Images Conversion using PyMuPDF

These functions handle PDF processing and image preprocessing for optimal OCR performance.
PyMuPDF (fitz) is used because it provides faster rendering and better memory efficiency
compared to alternatives like pdf2image, and doesn't require external dependencies like Poppler.
"""

def pdf_to_images(
    pdf_path: str,
    dpi: int = 300,
    max_pages: Optional[int] = None
) -> List[Image.Image]:
    """
    Render the pages of a PDF document as RGB PIL Images.

    PyMuPDF renders at 72 DPI natively, so pages are scaled by ``dpi / 72``.
    Higher DPI gives sharper text but costs memory and time: a US-letter page at
    300 DPI is 2550x3300 px, roughly 25 MB as 8-bit RGB.

    Args:
        pdf_path: Absolute or relative path to the PDF file
        dpi: Dots per inch for rendering. Standard values are:
             - 72: Screen quality (fast, lower quality)
             - 150: Acceptable for basic OCR
             - 300: High quality for accurate OCR (recommended)
             - 600: Very high quality for small text
        max_pages: If a positive integer, render only the first ``max_pages``
                   pages. None or 0 renders every page. Limiting here avoids
                   rendering pages that a caller would discard anyway.

    Returns:
        List of PIL Images in RGB format, one image per rendered page

    Raises:
        FileNotFoundError: If ``pdf_path`` is not an existing file
        ValueError: If the PDF is password-protected
        fitz.FileDataError: If the file is not a valid PDF
    """
    # isfile (not exists) so a directory path fails here with a clear message
    if not os.path.isfile(pdf_path):
        raise FileNotFoundError(f"PDF file not found: {pdf_path}")

    print(f"Converting PDF to images at {dpi} DPI...")

    # PyMuPDF's base resolution is 72 DPI; scale relative to that
    zoom = dpi / 72.0
    mat = fitz.Matrix(zoom, zoom)
    images = []

    # The context manager closes the document even if rendering raises
    with fitz.open(pdf_path) as doc:
        # Encrypted documents cannot be rendered without authentication;
        # fail early with an explicit message instead of a generic access error
        if doc.needs_pass:
            raise ValueError(f"PDF is password-protected: {pdf_path}")

        page_count = len(doc)
        if max_pages is not None and max_pages > 0:
            page_count = min(page_count, max_pages)

        for page_num in range(page_count):
            page = doc[page_num]

            # alpha=False drops the transparency channel so samples are plain RGB
            pix = page.get_pixmap(matrix=mat, alpha=False)

            # The model processor consumes PIL Images, so wrap the raw RGB samples
            img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
            images.append(img)

            print(f"  Processed page {page_num + 1}/{page_count}: {pix.width}x{pix.height}px")

    print(f"Successfully converted {len(images)} pages")
    return images


def preprocess_image_for_ocr(
    image: Image.Image,
    max_size: int = 2048
) -> Image.Image:
    """
    Downscale an image so that its longer side is at most ``max_size`` pixels.

    Very large images increase memory use during inference without necessarily
    improving OCR accuracy. Oversized images are shrunk with aspect ratio
    preserved. LANCZOS resampling is used because it gives high-quality
    downscaling, which helps keep small text legible.

    Args:
        image: Input PIL Image in any mode
        max_size: Maximum allowed width or height in pixels

    Returns:
        The input image unchanged (same object) if it already fits, otherwise a
        resized copy whose longer side equals ``max_size``.
    """
    width, height = image.size
    longest = max(width, height)

    # Images already within bounds are returned untouched (no quality loss)
    if longest <= max_size:
        return image

    scale = max_size / longest
    # Pin the long side to exactly max_size (width * scale can otherwise round one
    # pixel short under float arithmetic), and round the short side, never below
    # 1 px, so extreme aspect ratios cannot produce a zero-sized target (PIL rejects those).
    if width >= height:
        new_width, new_height = max_size, max(1, round(height * scale))
    else:
        new_width, new_height = max(1, round(width * scale)), max_size

    image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
    print(f"    Image resized: {width}x{height} -> {new_width}x{new_height}")

    return image


print("PDF processing functions defined successfully!")

In [None]:
"""
## Step 4: OCR Generation with Retry Logic and Confidence Scoring

This cell implements a robust OCR pipeline with multiple inference attempts per page
and confidence scoring to assess output quality. The selection strategy is optimized
for medical pathology reports where accuracy is critical.
"""

# Sampling settings passed to model.generate() for each OCR attempt.
# Attempt i uses TEMPERATURE_SCHEDULE[i % len(TEMPERATURE_SCHEDULE)]; all values
# are kept low because transcription should stay close to deterministic.
TEMPERATURE_SCHEDULE = [
    0.1,
    0.2,
    0.3,
]
TOP_P_THRESHOLD = 0.95      # nucleus-sampling cutoff (top_p)
REPETITION_PENALTY = 1.1    # mild discouragement of repeated tokens

# Quality-gate thresholds used when scoring attempts and flagging pages for review
QUALITY_SCORE_THRESHOLD = 0.5   # composite score below this -> warning
PERPLEXITY_THRESHOLD = 20.0     # perplexity above this -> warning
MIN_PROB_THRESHOLD = 0.05       # any token probability below this -> warning/penalty


def calculate_confidence_scores(scores, generated_ids) -> Dict:
    """
    Calculate token-level confidence metrics for one generated sequence.

    Lower confidence may indicate poor image quality, unusual fonts, or model
    uncertainty, and is used downstream to gate OCR quality.

    Args:
        scores: Sequence (e.g. the tuple returned by ``generate``) with one tensor
            per generation step, each of shape (batch, vocab); only batch row 0
            is used. Pass raw model logits (``generate(..., output_logits=True)``)
            to measure the model's own certainty. Processed ``scores`` (after
            temperature / top-p / repetition penalty) also work, but low
            temperatures sharpen them and inflate the probabilities.
        generated_ids: 1D tensor (or sequence) of generated token IDs, aligned
            step-by-step with ``scores`` (prompt tokens already removed).
            If the lengths differ, only the common prefix is scored.

    Returns:
        Dictionary with confidence metrics (all None when there is nothing to score):
            - mean_probability: Average probability of the chosen tokens (0-1)
            - mean_log_probability: Average log probability of the chosen tokens
            - perplexity: exp(-mean_log_probability); 1.0 = fully certain, higher = less certain
            - min_probability: Lowest chosen-token probability (flags uncertain spots)
    """
    empty = {
        'mean_probability': None,
        'mean_log_probability': None,
        'perplexity': None,
        'min_probability': None
    }
    if scores is None or len(scores) == 0:
        return empty

    # One host sync for all ids instead of a .item() call per token
    token_ids = torch.as_tensor(generated_ids).reshape(-1).tolist()
    if not token_ids:
        return empty

    # log_softmax in float32 is numerically stable and avoids log(0); computing it
    # per step keeps peak memory at one vocab-sized row rather than the full
    # (steps x vocab) matrix. zip() scores only the aligned common prefix.
    chosen_log_probs = [
        F.log_softmax(step_logits[0].float(), dim=-1)[token_id]
        for step_logits, token_id in zip(scores, token_ids)
    ]
    log_probs = torch.stack(chosen_log_probs).double().cpu().numpy()

    # Raw probabilities of the chosen tokens (a masked -inf logit maps to 0.0)
    token_probs = np.exp(log_probs)

    # Floor log-probs at log(1e-10) so a single masked token cannot make the
    # perplexity infinite (mirrors the epsilon of the earlier log(p + 1e-10) form)
    floored_log_probs = np.maximum(log_probs, np.log(1e-10))
    mean_log_prob = float(np.mean(floored_log_probs))

    return {
        'mean_probability': float(np.mean(token_probs)),
        'mean_log_probability': mean_log_prob,
        'perplexity': float(np.exp(-mean_log_prob)),
        'min_probability': float(np.min(token_probs))
    }


def calculate_composite_score(response: Dict) -> float:
    """
    Calculate a composite quality score for one OCR attempt.

    For medical pathology reports accuracy is paramount, so the score weights
    model confidence (80%) far above output length (20%). It then applies
    multiplicative penalties for very uncertain tokens and for sampling above
    the lowest temperature. These weights were chosen via trial-and-error.

    Args:
        response: OCR attempt dict with at least 'chars' (int), 'temperature'
            (float) and 'confidence' (dict as produced by
            calculate_confidence_scores).

    Returns:
        Float score (0-1 for normal inputs) where higher indicates better quality.
        Returns 0.0 if perplexity is unavailable.
    """
    chars = response['chars']
    confidence = response['confidence']
    temperature = response['temperature']

    # Without confidence data there is nothing to trust: rank lowest
    perplexity = confidence.get('perplexity')
    if perplexity is None:
        return 0.0

    min_prob = confidence.get('min_probability')
    if min_prob is None:
        min_prob = 0.0

    # Completeness on a 0-1 scale, treating 2000 chars as a "full page"
    normalized_length = min(chars / 2000.0, 1.0)

    # Map perplexity to a 0-1 quality score via 1 / (1 + ln(ppl)):
    # ppl 1.0 -> 1.00, 10 -> 0.30, 100 -> 0.18 (lower perplexity = higher quality)
    quality_score = 1.0 / (1.0 + np.log(max(perplexity, 1.0)))

    # Penalize outputs containing a very uncertain token (potential hallucination).
    # Uses the same comparison as the warning in select_best_ocr, so a penalized
    # attempt is also the one that gets flagged.
    min_prob_penalty = 0.7 if min_prob < MIN_PROB_THRESHOLD else 1.0

    # Prefer the most conservative (lowest temperature) sampling
    temperature_penalty = 1.0 if temperature <= 0.1 else 0.95

    final_score = (
        (0.80 * quality_score) +
        (0.20 * normalized_length)
    ) * min_prob_penalty * temperature_penalty

    return float(final_score)


def select_best_ocr(all_responses: List[Dict]) -> Optional[Dict]:
    """
    Select the highest-scoring OCR attempt and attach quality-gate flags.

    Each attempt is scored with calculate_composite_score. The winner (mutated
    in place, so the same dict inside ``all_responses`` is updated) gains:
      - 'composite_score': its score
      - 'quality_warning': True if any warning was raised
      - 'quality_warnings': list of warning strings
      - 'warning_message': human-readable summary (only when warnings exist)

    Args:
        all_responses: List of OCR attempt dicts, each with 'chars',
            'temperature' and 'confidence'

    Returns:
        The best attempt dict, or None if ``all_responses`` is empty.
    """
    if not all_responses:
        return None

    scored_responses = [(calculate_composite_score(r), r) for r in all_responses]

    # max() keeps the earliest attempt on ties
    best_score, best_response = max(scored_responses, key=lambda pair: pair[0])
    best_response['composite_score'] = best_score

    # Confidence values may be present but None (no tokens were scored), so
    # never compare them to thresholds without checking first
    confidence = best_response.get('confidence') or {}
    perplexity = confidence.get('perplexity')
    min_prob = confidence.get('min_probability')

    quality_warnings = []

    if best_score < QUALITY_SCORE_THRESHOLD:
        quality_warnings.append(
            f"Low composite quality score: {best_score:.3f} (threshold: {QUALITY_SCORE_THRESHOLD})"
        )

    if perplexity is None:
        quality_warnings.append(
            "Confidence scores unavailable - transcription certainty could not be assessed"
        )
    else:
        if perplexity > PERPLEXITY_THRESHOLD:
            quality_warnings.append(
                f"High perplexity (uncertainty): {perplexity:.2f} (threshold: {PERPLEXITY_THRESHOLD})"
            )
        if min_prob is None:
            min_prob = 0.0
        if min_prob < MIN_PROB_THRESHOLD:
            quality_warnings.append(
                f"Very uncertain tokens detected: min_prob={min_prob:.4f} (threshold: {MIN_PROB_THRESHOLD})"
            )

    # Large spread in output length across attempts suggests unstable reading
    if len(all_responses) > 1:
        char_counts = [r['chars'] for r in all_responses]
        mean_chars = np.mean(char_counts)
        cv = np.std(char_counts) / mean_chars if mean_chars > 0 else 0
        if cv > 0.3:  # >30% variation
            quality_warnings.append(
                f"High variance between attempts: {cv:.1%} - results inconsistent"
            )

    best_response['quality_warning'] = len(quality_warnings) > 0
    best_response['quality_warnings'] = quality_warnings

    if quality_warnings:
        best_response['warning_message'] = (
            "QUALITY ALERT - Manual review recommended:\n" +
            "\n".join(f"  - {w}" for w in quality_warnings)
        )

    return best_response


def ocr_with_retry(
    image: Image.Image,
    page_num: int = 1,
    num_attempts: int = 3,
    max_new_tokens: int = 2048,
    use_cot: bool = True
) -> Dict:
    """
    Run OCR on one page image several times and keep the best-scoring attempt.

    Attempt i samples with temperature TEMPERATURE_SCHEDULE[i % len(...)].
    Token confidence is computed from the model's raw logits
    (``output_logits=True``) rather than the processed ``scores``. Processed
    scores are divided by the temperature and top-p filtered, so at
    temperature 0.1 nearly every chosen token looks certain (perplexity ~1)
    and the quality gate in select_best_ocr would almost never fire.
    NOTE(review): the module thresholds may have been tuned against the old
    processed-score metric; re-validate them on sample reports.

    Args:
        image: PIL Image to perform OCR on
        page_num: Page number for logging and result tracking
        num_attempts: Number of OCR attempts with different temperatures
        max_new_tokens: Maximum tokens to generate per attempt
        use_cot: Whether to use Chain-of-Thought prompting

    Returns:
        On success: dict with 'page_num', 'best_response' (see select_best_ocr),
        'all_responses' (one dict per successful attempt) and 'total_attempts'
        (number of successful attempts).
        If every attempt fails: {'page_num': page_num, 'error': <message>}.
    """

    if use_cot:
        prompt = """Extract all text from this medical pathology report image with perfect accuracy. This is a clinical document where precision is critical.

Work systematically through these steps:

1. Identify the document structure (patient info, diagnostic sections, lab values, tables)
2. Read each section carefully from top to bottom, left to right
3. Pay special attention to:
   - Patient identifiers and demographics
   - Diagnosis codes and staging information
   - Numerical values (measurements, lab results, percentages)
   - Dates and timestamps
   - Physician names and signatures
4. Preserve all formatting, line breaks, and special characters exactly
5. Double-check all numerical values and medical terminology
6. Output the complete text transcription exactly as shown

Provide the full text transcription:"""
    else:
        prompt = "Read and transcribe all text from this medical document image exactly as shown, preserving all formatting and structure."

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt}
            ]
        }
    ]

    # Inputs are identical for every attempt, so build them once
    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt"
    ).to(model.device)

    all_responses = []

    print(f"\nPage {page_num} - Running {num_attempts} OCR attempts...")

    for i in range(num_attempts):
        temp = TEMPERATURE_SCHEDULE[i % len(TEMPERATURE_SCHEDULE)]
        print(f"  Attempt [{i+1}/{num_attempts}] (temperature={temp})...", end=" ")

        start = time.time()
        # Per-step logits are large (steps x vocab); keep explicit handles so they
        # can be released in `finally` before the next attempt allocates its own
        generation_output = None
        step_logits = None

        try:
            with torch.inference_mode():
                # return_dict_in_generate=True gives a structured output;
                # output_logits=True attaches the raw, unprocessed logits per step
                generation_output = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temp,
                    top_p=TOP_P_THRESHOLD,
                    do_sample=temp > 0,
                    repetition_penalty=REPETITION_PENALTY,
                    return_dict_in_generate=True,
                    output_logits=True
                )

            # Strip the prompt tokens so only newly generated tokens remain
            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs.input_ids, generation_output.sequences)
            ]

            output_text = processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )[0]

            elapsed = time.time() - start
            char_count = len(output_text)
            word_count = len(output_text.split())

            # Prefer raw logits; fall back to processed scores if they are absent
            step_logits = (
                getattr(generation_output, "logits", None)
                or getattr(generation_output, "scores", None)
            )
            confidence = calculate_confidence_scores(step_logits, generated_ids_trimmed[0])

            all_responses.append({
                'attempt': i + 1,
                'text': output_text,
                'temperature': temp,
                'time': elapsed,
                'chars': char_count,
                'words': word_count,
                'confidence': confidence
            })

            perplexity = confidence['perplexity']
            perp_display = (
                f"perplexity={perplexity:.2f}" if perplexity is not None else "perplexity=N/A"
            )
            print(f"Success - {char_count} chars, {perp_display} ({elapsed:.1f}s)")

        except Exception as e:
            # Best-effort: one failed attempt (e.g. OOM) should not abort the page
            print(f"Failed - {str(e)}")
            print(f"  Error details: {traceback.format_exc()}")

        finally:
            # Release this attempt's tensors and cached GPU blocks on success AND
            # failure, so the next attempt does not start with them still resident
            generation_output = None
            step_logits = None
            torch.cuda.empty_cache()

    if not all_responses:
        return {
            'page_num': page_num,
            'error': f"All OCR attempts failed for page {page_num}"
        }

    best = select_best_ocr(all_responses)

    return {
        'page_num': page_num,
        'best_response': best,
        'all_responses': all_responses,
        'total_attempts': len(all_responses)
    }


def ocr_pdf(
    pdf_path: str,
    dpi: int = 300,
    attempts_per_page: int = 3,
    use_cot: bool = True,
    max_pages: Optional[int] = None,
    max_new_tokens: int = 2048
) -> Dict:
    """
    Perform OCR on an entire PDF document with retry logic and confidence scoring per page.

    Args:
        pdf_path: Path to the PDF file
        dpi: Resolution for PDF rendering
        attempts_per_page: Number of OCR attempts per page
        use_cot: Whether to use Chain-of-Thought prompting
        max_pages: Optional limit on pages to process; a falsy value (None or 0)
                   processes every page
        max_new_tokens: Maximum tokens generated per OCR attempt

    Returns:
        Dictionary with:
            - 'pages': per-page result dicts in page order (each has 'page_num')
            - 'full_text': concatenated best transcriptions with page banners
            - 'total_pages', 'total_chars', 'total_time' (seconds)
            - 'avg_perplexity', 'avg_quality_score': plain floats, or None when
              no page produced the metric
            - 'pages_with_warnings': number of pages flagged for manual review
    """
    print("=" * 80)
    print("PDF OCR WITH QUALITY GATING")
    print("=" * 80)

    images = pdf_to_images(pdf_path, dpi=dpi)

    if max_pages:
        images = images[:max_pages]
        print(f"Processing first {max_pages} pages only (limit applied)")

    total_pages = len(images)
    print(f"Total pages to process: {total_pages}\n")

    all_results = []
    pages_with_warnings = 0
    total_start = time.time()

    for i, image in enumerate(images, 1):
        processed_img = preprocess_image_for_ocr(image)

        result = ocr_with_retry(
            image=processed_img,
            page_num=i,
            num_attempts=attempts_per_page,
            max_new_tokens=max_new_tokens,
            use_cot=use_cot
        )
        # Failed-page results may lack a page number; record it so downstream
        # consumers (e.g. the JSON report) can still identify the page
        result.setdefault('page_num', i)
        all_results.append(result)

        if 'best_response' in result:
            best = result['best_response']
            perplexity = best['confidence']['perplexity']
            perp_str = f"{perplexity:.2f}" if perplexity is not None else "N/A"
            score_str = f"{best.get('composite_score', 0):.3f}"

            print(f"  Page {i}/{total_pages} completed: "
                  f"{best['chars']} chars, perplexity={perp_str}, quality={score_str}")

            if best.get('quality_warning'):
                pages_with_warnings += 1
                print(f"  {best.get('warning_message', '')}")
            print()
        else:
            print(f"  Page {i}/{total_pages} failed: {result.get('error', 'Unknown error')}\n")

    total_time = time.time() - total_start

    successful = [r for r in all_results if 'best_response' in r]

    full_text = "\n\n" + "\n\n".join([
        f"{'=' * 80}\nPAGE {r['page_num']}\n{'=' * 80}\n{r['best_response']['text']}"
        for r in successful
    ])

    total_chars = sum(r['best_response']['chars'] for r in successful)

    # Average only over pages that produced the metric. Guard against empty lists:
    # np.mean([]) returns NaN (with a warning), which would leak into reports.
    perplexities = [
        r['best_response']['confidence']['perplexity']
        for r in successful
        if r['best_response']['confidence'].get('perplexity') is not None
    ]
    avg_perplexity = float(np.mean(perplexities)) if perplexities else None

    quality_scores = [r['best_response'].get('composite_score', 0) for r in successful]
    avg_quality = float(np.mean(quality_scores)) if quality_scores else None

    # An empty PDF (or one reduced to zero pages) must not divide by zero
    avg_page_time = total_time / total_pages if total_pages else 0.0

    print("\n" + "=" * 80)
    print("PDF OCR COMPLETE")
    print("=" * 80)
    print(f"Pages processed: {total_pages}")
    print(f"Total characters: {total_chars:,}")
    print(f"Average perplexity: {avg_perplexity:.2f}" if avg_perplexity is not None else "Average perplexity: N/A")
    print(f"Average quality score: {avg_quality:.3f}" if avg_quality is not None else "Average quality score: N/A")
    print(f"Pages with quality warnings: {pages_with_warnings}/{total_pages}")
    print(f"Total time: {total_time:.1f}s (average: {avg_page_time:.1f}s per page)")

    if pages_with_warnings > 0:
        print(f"\nWARNING: {pages_with_warnings} page(s) flagged for manual review")
    print("=" * 80)

    return {
        'pages': all_results,
        'full_text': full_text,
        'total_pages': total_pages,
        'total_chars': total_chars,
        'total_time': total_time,
        'avg_perplexity': avg_perplexity,
        'avg_quality_score': avg_quality,
        'pages_with_warnings': pages_with_warnings
    }


print("OCR functions with quality gating defined successfully!")

In [None]:
"""
## Step 5: Result Saving Functions

This cell implements functions to save OCR results to text and JSON formats.
Text files contain the extracted content, while JSON files preserve all metadata
including confidence scores and quality warnings for detailed analysis.
"""


def save_results(
    results: Dict,
    filename: str,
    include_metadata: bool = True
) -> None:
    """
    Save OCR results to a UTF-8 text file, optionally preceded by a metadata header.

    Args:
        results: Dictionary containing OCR results from ocr_pdf()
        filename: Output file path (absolute or relative); overwritten if it exists
        include_metadata: If True, prepend metadata header to output

    Raises:
        IOError: If file cannot be written (a message is printed, then re-raised)
    """
    import math  # local import: only needed to recognise NaN averages

    def _has_value(value) -> bool:
        # None means "not computed" and NaN means "averaged over zero pages".
        # A genuine 0.0 (e.g. every page scored 0) is still worth reporting.
        return value is not None and not (isinstance(value, float) and math.isnan(value))

    try:
        with open(filename, 'w', encoding='utf-8') as f:
            if include_metadata:
                f.write("=" * 80 + "\n")
                f.write("OCR METADATA\n")
                f.write("=" * 80 + "\n")
                f.write(f"Processing Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
                f.write(f"Total Pages: {results.get('total_pages', 0)}\n")
                f.write(f"Total Characters: {results.get('total_chars', 0):,}\n")

                avg_perp = results.get('avg_perplexity')
                if _has_value(avg_perp):
                    f.write(f"Average Perplexity: {avg_perp:.2f}\n")

                avg_quality = results.get('avg_quality_score')
                if _has_value(avg_quality):
                    f.write(f"Average Quality Score: {avg_quality:.3f}\n")

                warnings = results.get('pages_with_warnings', 0)
                if warnings > 0:
                    f.write(f"\nQuality Warnings: {warnings} page(s) flagged for manual review\n")

                f.write(f"Processing Time: {results.get('total_time', 0):.1f}s\n")
                f.write("=" * 80 + "\n\n")

            f.write(results.get('full_text', ''))

        print(f"Results saved to: {filename}")

    except IOError as e:
        print(f"Error saving results to {filename}: {str(e)}")
        raise


def _json_safe(value):
    """Return a copy of *value* with non-finite floats (NaN, +/-inf) replaced by None.

    ``json.dump`` emits NaN/Infinity as bare tokens, which are not valid JSON and
    break strict parsers; None is written as ``null`` instead. Dicts, lists and
    tuples are processed recursively (tuples become lists, as JSON would do anyway).
    """
    import math  # local import keeps this helper self-contained

    if isinstance(value, float):
        return value if math.isfinite(value) else None
    if isinstance(value, dict):
        return {key: _json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(item) for item in value]
    return value


def save_detailed_results(
    results: Dict,
    filename: str
) -> None:
    """
    Save comprehensive OCR results to JSON, including every attempt and its metrics.

    Preserves all OCR attempts, confidence scores, quality warnings and metadata
    for detailed analysis and quality-assurance workflows. Non-finite floats
    (e.g. a NaN average) are written as null so the file is strict JSON.

    Args:
        results: Dictionary containing OCR results from ocr_pdf()
        filename: Output JSON file path (absolute or relative)

    Raises:
        IOError: If file cannot be written (a message is printed, then re-raised)
    """
    try:
        output = {
            'metadata': {
                'processing_date': datetime.now().isoformat(),
                'total_pages': results.get('total_pages', 0),
                'total_chars': results.get('total_chars', 0),
                'total_time_seconds': results.get('total_time', 0),
                'avg_perplexity': results.get('avg_perplexity'),
                'avg_quality_score': results.get('avg_quality_score'),
                'pages_with_warnings': results.get('pages_with_warnings', 0)
            },
            'pages': []
        }

        for page_result in results.get('pages', []):
            if 'error' in page_result:
                output['pages'].append({
                    'page_num': page_result.get('page_num'),
                    'error': page_result['error']
                })
                continue

            best = page_result['best_response']
            output['pages'].append({
                'page_num': page_result['page_num'],
                'total_attempts': page_result['total_attempts'],
                'best_response': {
                    'attempt': best['attempt'],
                    'temperature': best['temperature'],
                    'chars': best['chars'],
                    'words': best['words'],
                    'time_seconds': best['time'],
                    'composite_score': best.get('composite_score'),
                    'confidence': best['confidence'],
                    'quality_warning': best.get('quality_warning', False),
                    'quality_warnings': best.get('quality_warnings', []),
                    'text': best['text']
                },
                'all_attempts': [
                    {
                        'attempt': resp['attempt'],
                        'temperature': resp['temperature'],
                        'chars': resp['chars'],
                        'words': resp['words'],
                        'time_seconds': resp['time'],
                        'confidence': resp['confidence']
                    }
                    for resp in page_result['all_responses']
                ]
            })

        # Serialise completely before opening the file, so an unserialisable value
        # raises without leaving a truncated JSON file behind
        payload = json.dumps(_json_safe(output), indent=2, ensure_ascii=False)
        with open(filename, 'w', encoding='utf-8') as f:
            f.write(payload)

        print(f"Detailed results saved to: {filename}")

    except IOError as e:
        print(f"Error saving detailed results to {filename}: {str(e)}")
        raise


print("Result saving functions defined successfully!")

In [None]:
"""
## Step 6: Batch Process PDFs in Directory Structure

This cell implements recursive directory traversal to process all PDFs found in
a folder hierarchy. Results are saved alongside source files to maintain organization.
"""


def find_pdf_files(root_directory: str, skip_processed: bool = True) -> List[Tuple[str, str]]:
    """
    Recursively find PDF files under a directory, in a deterministic order.

    For every ``<stem>.pdf`` (extension matched case-insensitively), the output
    path is ``<stem>_ocr.txt`` in the same directory. Directories and files are
    visited in sorted order, so repeated runs (e.g. resuming an interrupted
    batch) process files in the same sequence regardless of the filesystem's
    listing order.

    Args:
        root_directory: Path to the root directory to search
        skip_processed: If True, skip PDFs whose ``<stem>_ocr.txt`` output file
                        already exists. Useful for resuming interrupted batch jobs.

    Returns:
        List of (pdf_path, output_txt_path) tuples for each PDF to process

    Raises:
        FileNotFoundError: If root_directory does not exist
    """
    if not os.path.exists(root_directory):
        raise FileNotFoundError(f"Directory not found: {root_directory}")

    print(f"Scanning directory: {root_directory}")

    pdf_files = []
    for dirpath, dirnames, filenames in os.walk(root_directory):
        # Sorting dirnames in place makes os.walk descend in sorted order as well
        dirnames.sort()

        for filename in sorted(filenames):
            if not filename.lower().endswith('.pdf'):
                continue

            pdf_path = os.path.join(dirpath, filename)

            # Output lives next to the source: "<stem>_ocr.txt"
            output_filename = os.path.splitext(filename)[0] + '_ocr.txt'
            output_path = os.path.join(dirpath, output_filename)

            if skip_processed and os.path.exists(output_path):
                print(f"  Skipping (already processed): {pdf_path}")
                continue

            pdf_files.append((pdf_path, output_path))

    print(f"Found {len(pdf_files)} PDF(s) to process")
    return pdf_files


def process_pdf_batch(
    root_directory: str,
    dpi: int = 300,
    attempts_per_page: int = 3,
    use_cot: bool = True,
    max_pages: Optional[int] = None,
    skip_processed: bool = True,
    save_detailed: bool = False
) -> Dict:
    """
    Process all PDFs found in a directory tree with OCR.

    This function orchestrates batch OCR processing across multiple files. It handles
    errors gracefully so that one failed PDF doesn't stop the entire batch. Progress
    is tracked and reported to help monitor long-running jobs.

    Args:
        root_directory: Root directory to search for PDFs
        dpi: Resolution for PDF rendering
        attempts_per_page: Number of OCR attempts per page
        use_cot: Whether to use Chain-of-Thought prompting
        max_pages: Optional limit on pages per PDF (useful for testing)
        skip_processed: Skip PDFs that already have output files
        save_detailed: If True, also save detailed JSON output with all attempts

    Returns:
        Dictionary with batch processing statistics
    """

    print("=" * 80)
    print("BATCH PDF OCR PROCESSING")
    print("=" * 80)

    # Find all PDFs to process
    pdf_files = find_pdf_files(root_directory, skip_processed=skip_processed)

    if not pdf_files:
        print("\nNo PDFs found to process.")
        return {
            'total_files': 0,
            'successful': 0,
            'failed': 0,
            'skipped': 0
        }

    # Track batch statistics
    total_files = len(pdf_files)
    successful = 0
    failed = 0
    failed_files = []
    batch_start = time.time()

    print(f"\nProcessing {total_files} PDF file(s)...\n")

    # Process each PDF individually
    # Using enumerate for progress tracking
    for idx, (pdf_path, output_path) in enumerate(pdf_files, 1):
        print("=" * 80)
        print(f"FILE {idx}/{total_files}: {os.path.basename(pdf_path)}")
        print(f"Location: {os.path.dirname(pdf_path)}")
        print("=" * 80)

        try:
            # Run OCR on the PDF
            # Each file is processed independently to isolate errors
            results = ocr_pdf(
                pdf_path=pdf_path,
                dpi=dpi,
                attempts_per_page=attempts_per_page,
                use_cot=use_cot,
                max_pages=max_pages
            )

            # Save text output to the same directory as source PDF
            # This maintains the organizational structure of the input
            save_results(
                results=results,
                filename=output_path,
                include_metadata=True
            )

            # Optionally save detailed JSON output with all metadata
            # Useful for quality analysis or when you need access to all attempts
            if save_detailed:
                detailed_path = output_path.replace('_ocr.txt', '_ocr_detailed.json')
                save_detailed_results(results, filename=detailed_path)

            successful += 1
            print(f"\nFile {idx}/{total_files} completed successfully")

        except Exception as e:
            # Log error but continue processing remaining files
            # This ensures one problematic PDF doesn't halt the entire batch
            failed += 1
            failed_files.append({
                'file': pdf_path,
                'error': str(e)
            })

            print(f"\nError processing {pdf_path}:")
            print(f"  {str(e)}")
            print("\nFull traceback:")
            print(traceback.format_exc())
            print("\nContinuing with next file...")

        # Clear GPU memory between files to prevent accumulation
        # This is critical for processing large batches without memory errors
        torch.cuda.empty_cache()
        print()

    batch_time = time.time() - batch_start

    # Print final batch summary
    print("\n" + "=" * 80)
    print("BATCH PROCESSING COMPLETE")
    print("=" * 80)
    print(f"Total files: {total_files}")
    print(f"Successful: {successful}")
    print(f"Failed: {failed}")
    print(f"Total time: {batch_time:.1f}s ({batch_time/60:.1f}m)")
    
    if failed_files:
        print(f"\nFailed files:")
        for ff in failed_files:
            print(f"  - {ff['file']}: {ff['error']}")
    
    print("=" * 80)

    return {
        'total_files': total_files,
        'successful': successful,
        'failed': failed,
        'failed_files': failed_files,
        'batch_time_seconds': batch_time
    }


print("Batch processing functions defined successfully!")

In [None]:
"""
## Step 7: Execute Batch Processing

Run this cell to process all PDFs in the specified directory (searched recursively).
Adjust the parameters below based on your needs. With skip_processed=True, re-running
this cell skips PDFs that already have an output file, so an interrupted batch can
be resumed.
"""

# Configure batch processing parameters
results = process_pdf_batch(
    root_directory='reports',          # Directory containing PDF files
    dpi=300,                            # Image resolution (300 recommended)
    attempts_per_page=3,                # Number of OCR attempts per page
    skip_processed=True,                # Skip files that already have output
    use_cot=True,                       # Use medical-optimized CoT prompting
    max_pages=None,                     # Limit pages per PDF (None = all pages)
    save_detailed=False                 # Save detailed JSON with all attempts
)

# Display batch statistics.
# .get() keeps the summary working even if the result dict lacks
# 'batch_time_seconds' (e.g. an early "no PDFs found" return).
print("\n\nBatch Processing Summary:")
print(f"  Files processed successfully: {results['successful']}/{results['total_files']}")
print(f"  Files failed: {results['failed']}")
print(f"  Processing time: {results.get('batch_time_seconds', 0.0)/60:.1f} minutes")