In [1]:
import os
import warnings
from pathlib import Path
from typing import List, Optional
import argparse

from PIL import Image
import torch


In [2]:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

In [3]:
INPUT_FOLDER = os.path.join("Data", "JPG")
# INPUT_FOLDER = os.path.join("Data", "test")
OUTPUT_FOLDER = "ocr_output_ref"

In [4]:
def check_cuda_compatibility():
    """
    Check CUDA compatibility and provide diagnostic information
    """
    print("=== CUDA Compatibility Check ===")
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    
    if torch.cuda.is_available():
        print(f"CUDA version (PyTorch): {torch.version.cuda}")
        print(f"cuDNN version: {torch.backends.cudnn.version()}")
        print(f"GPU count: {torch.cuda.device_count()}")
        for i in range(torch.cuda.device_count()):
            print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
        
        # Check memory
        for i in range(torch.cuda.device_count()):
            print(f"GPU {i} memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.1f} GB")
    
    print("================================\n")

def fix_cuda_issues():
    """
    Apply various CUDA fixes
    """
    print("Applying CUDA fixes...")
    
    # Fix 1: Set CUDA environment variables
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    os.environ['TORCH_USE_CUDA_DSA'] = '1'
    
    # Fix 2: Set cuDNN deterministic mode
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    # Fix 3: Clear CUDA cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    print("CUDA fixes applied.\n")
    
def test_cuda_operations():
    """
    Test basic CUDA operations to identify the problematic layer
    """
    print("Testing CUDA operations...")
    
    device = torch.device('cuda:0')
    
    # Test 1: Basic tensor operations
    try:
        x = torch.randn(10, 10).to(device)
        y = torch.matmul(x, x)
        print("✅ Basic tensor operations work")
    except Exception as e:
        print(f"❌ Basic tensor operations failed: {e}")
        return False
    
    # Test 2: Convolution operations (likely culprit)
    try:
        conv = torch.nn.Conv2d(3, 64, kernel_size=3).to(device)
        x = torch.randn(1, 3, 224, 224).to(device)
        y = conv(x)
        print("✅ Convolution operations work")
    except Exception as e:
        print(f"❌ Convolution operations failed: {e}")
        print("This is likely the source of the problem!")
        return False
    
    # Test 3: Transformer operations
    try:
        from transformers import AutoModel
        model = AutoModel.from_pretrained("distilbert-base-uncased").to(device)
        print("✅ Transformer operations work")
    except Exception as e:
        print(f"❌ Transformer operations failed: {e}")
        return False
    
    return True

In [5]:
def setup_trocr_model():
    """
    Step 3: Load the TrOCR model and processor
    Using the handwritten text model variant
    """
    print("Loading TrOCR model for handwritten text...")
    
    # Load processor and model for handwritten text
    processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten",)
    model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
    
    # Use GPU if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    
    print(f"Model loaded on device: {device}")
    return processor, model, device


In [6]:

def preprocess_image(image_path):
    """
    Step 4: Load and preprocess the image
    """
    print(f"Loading image: {image_path}")
    
    # Load image
    image = Image.open(image_path).rotate(180)
    
    # Convert to RGB if necessary
    if image.mode != 'RGB':
        image = image.convert('RGB')
    
    print(f"Image size: {image.size}")
    print(f"Image mode: {image.mode}")
    
    return image


In [7]:

def perform_ocr(image, processor, model, device):
    """
    Step 5: Perform OCR on the image
    """
    print("Processing image with TrOCR...")
    
    # Process the image
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
    
    # Generate text
    with torch.no_grad():
        generated_ids = model.generate(pixel_values)
    
    # Decode the generated text
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    
    return generated_text


In [8]:
def get_image_files(folder_path: str) -> List[str]:
    """Get all image files from folder"""
    folder = Path(folder_path)
    if not folder.exists():
        print(f"Error: Folder '{folder_path}' does not exist")
        return []
    
    image_files = []
    for ext in ['*.jpg', '*.jpeg', '*.png', '*.tiff', '*.bmp']:
        image_files.extend(folder.glob(ext))
        image_files.extend(folder.glob(ext.upper()))
    
    return sorted([str(f) for f in image_files])

In [9]:

def process_multiple_lines(image, processor, model, device, crop_height=100):
    """
    Step 6: Process image line by line for better results
    This is useful for handwritten text with multiple lines
    """
    print("Processing image line by line...")
    
    width, height = image.size
    results = []
    
    # Process in horizontal strips
    y_start = 0
    line_num = 1
    
    while y_start < height:
        y_end = min(y_start + crop_height, height)
        
        # Crop the line
        line_image = image.crop((0, y_start, width, y_end))
        
        # Skip if the crop is too small
        if line_image.size[1] < 20:
            break
            
        print(f"Processing line {line_num}: y={y_start}-{y_end}")
        
        # Perform OCR on the line
        try:
            line_text = perform_ocr(line_image, processor, model, device)
            if line_text.strip():  # Only add non-empty results
                results.append(f"Line {line_num}: {line_text}")
            line_num += 1
        except Exception as e:
            print(f"Error processing line {line_num}: {e}")
        
        y_start += crop_height
    
    return results


In [10]:

def main():
    """
    Main function to run the complete OCR pipeline
    """
    print("=== TrOCR Handwritten Text Recognition ===")
    
    # Step 1: Setup model
    processor, model, device = setup_trocr_model()
    
    # Step 2: Process your image
    image_path = "your_handwritten_image.jpg"  # Replace with your image path
    
    try:
        # Load image
        image = preprocess_image(image_path)
        
        # Method 1: Process entire image at once
        print("\n--- Method 1: Full Image Processing ---")
        full_text = perform_ocr(image, processor, model, device)
        print("Extracted text (full image):")
        print(full_text)
        
        # Method 2: Process line by line (often better for handwritten text)
        print("\n--- Method 2: Line-by-Line Processing ---")
        line_results = process_multiple_lines(image, processor, model, device)
        
        print("Extracted text (line by line):")
        for line in line_results:
            print(line)
            
        # Step 7: Save results
        with open("ocr_results.txt", "w", encoding="utf-8") as f:
            f.write("=== Full Image OCR ===\n")
            f.write(full_text + "\n\n")
            f.write("=== Line-by-Line OCR ===\n")
            for line in line_results:
                f.write(line + "\n")
        
        print("\nResults saved to 'ocr_results.txt'")
        
    except FileNotFoundError:
        print(f"Error: Could not find image file '{image_path}'")
        print("Please update the image_path variable with your actual image file path")
    except Exception as e:
        print(f"Error during processing: {e}")

# Additional utility functions

def batch_process_images(input_folder, output_dir="ocr_output_ref"):
    """
    Process multiple images in batch
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    image_files = get_image_files(input_folder)
    
    if not image_files:
            print(f"No image files found in '{input_folder}'")
            return
    print(f"\nFound {len(image_files)} image files")
    
    processor, model, device = setup_trocr_model()
    
    for i, image_path in enumerate(image_files, 1):
        print(f"\nProcessing image {i+1}/{len(image_files)}: {image_path}")
        
        try:
            image = preprocess_image(image_path)
            line_results = process_multiple_lines(image, processor, model, device)
            
            # Save individual result
            image_filename = os.path.basename(image_path)
            base_name, _ = os.path.splitext(image_filename)
            output_file = os.path.join(output_dir, f"{base_name}.txt")
            with open(output_file, 'w', encoding='utf-8') as f:
                f.write(f"=== TrOCR Result for {i} ===\n\n")
                for line in line_results:
                    f.write(line + "\n")
                f.write(f"\n\n=== End of {i} ===\n")
                
            print(f"Saved result to: {output_file}")
            
        except Exception as e:
            print(f"Error processing {image_path}: {e}")

def improve_ocr_accuracy_tips():
    """
    Tips for improving OCR accuracy with handwritten text
    """
    tips = """
    === Tips for Better OCR Results ===
    
    1. Image Quality:
       - Use high resolution images (300+ DPI)
       - Ensure good contrast between text and background
       - Minimize shadows and glare
    
    2. Preprocessing:
       - Consider image enhancement (contrast, brightness)
       - Rotate image if text is skewed
       - Crop to remove unnecessary borders
    
    3. TrOCR Model Variants:
       - microsoft/trocr-base-handwritten (general handwriting)
       - microsoft/trocr-large-handwritten (better accuracy, slower)
       - Fine-tune on your specific handwriting style if needed
    
    4. Processing Strategy:
       - Process line by line for multi-line text
       - Adjust crop_height based on your text size
       - Consider word-level processing for very challenging text
    
    5. Post-processing:
       - Use spell-check or language models to correct obvious errors
       - Manual review is often necessary for historical documents
    """
    print(tips)


In [11]:
batch_process_images(INPUT_FOLDER,OUTPUT_FOLDER)


Found 46 image files
Loading TrOCR model for handwritten text...


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-base-handwritten and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded on device: cuda

Processing image 2/46: Data/JPG/Beatrice_Moxon_diary_1885_0000.jpg
Loading image: Data/JPG/Beatrice_Moxon_diary_1885_0000.jpg
Image size: (4500, 3127)
Image mode: RGB
Processing image line by line...
Processing line 1: y=0-100
Processing image with TrOCR...
Processing line 2: y=100-200
Processing image with TrOCR...
Processing line 3: y=200-300
Processing image with TrOCR...
Processing line 4: y=300-400
Processing image with TrOCR...
Processing line 5: y=400-500
Processing image with TrOCR...
Processing line 6: y=500-600
Processing image with TrOCR...
Processing line 7: y=600-700
Processing image with TrOCR...
Processing line 8: y=700-800
Processing image with TrOCR...
Processing line 9: y=800-900
Processing image with TrOCR...
Processing line 10: y=900-1000
Processing image with TrOCR...
Processing line 11: y=1000-1100
Processing image with TrOCR...
Processing line 12: y=1100-1200
Processing image with TrOCR...
Processing line 13: y=1200-1300
Processing i