# OCR-Devanagari-CRNN — Dataset Analysis & Training Pipeline

## Overview
This notebook implements a complete pipeline for building an Optical Character Recognition (OCR) system for Devanagari script. The model is a Convolutional Recurrent Neural Network (CRNN): a CNN feature extractor followed by bidirectional LSTM layers, trained with CTC (Connectionist Temporal Classification) loss.

## Key Components:
1. **Data Source**: HuggingFace dataset "Sakonii/nepalitext-language-model-dataset" containing real Nepali text
2. **Synthetic Data Generation**: Uses HarfBuzz for proper Devanagari shaping and rendering with augmentations
3. **Model Architecture**: CRNN (CNN feature extractor + Bidirectional LSTM + CTC decoder)
4. **Training Ready**: Configuration and pipeline setup for training the OCR model

## Why This Approach?
- **HarfBuzz**: Ensures proper rendering of complex Devanagari ligatures and diacritics
- **Synthetic Data**: Allows controlled generation of training samples with various fonts and augmentations
- **CRNN (CNN + BiLSTM)**: The CNN turns the image into a left-to-right sequence of feature vectors, and the BiLSTM adds context from both directions
- **CTC Loss**: Trains on (image, text) pairs without requiring per-character position labels
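
To make the CTC point concrete, here is a minimal sketch of greedy CTC decoding: repeated per-frame predictions are collapsed and blank frames are dropped, which is why no per-character positions are needed. It assumes the blank sits at index 0 and characters at indices 1..N, as the charset step of this notebook does. The function name `ctc_greedy_decode` and the toy Latin charset are illustrative only and are not part of the pipeline.

```python
def ctc_greedy_decode(frame_indices, charset, blank=0):
    """Collapse repeated indices, then drop blanks (greedy CTC decoding)."""
    decoded = []
    previous = blank
    for idx in frame_indices:
        # Emit a character only when the index changes and is not the blank
        if idx != previous and idx != blank:
            decoded.append(charset[idx - 1])  # indices 1..N map to charset[0..N-1]
        previous = idx
    return "".join(decoded)


# Toy example: one predicted class index per time step (image column)
toy_charset = ["a", "b", "c"]
frames = [0, 1, 1, 0, 1, 2, 2, 0, 3]
print(ctc_greedy_decode(frames, toy_charset))
```

The blank between the two runs of index 1 is what lets a doubled character survive the collapse step.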

In [None]:
# Installation Commands (Run if dependencies not installed)
# !pip install --upgrade pip setuptools
# !pip install -r requirements.txt
# 
# Expected dependencies:
# - datasets (HuggingFace datasets library)
# - PIL/Pillow (image processing)
# - numpy (numerical operations)
# - matplotlib (visualization)
# - torch (deep learning framework)
# - freetype-py (font rendering)
# - uharfbuzz (text shaping for complex scripts)
# - opencv-python (computer vision operations)

In [None]:
"""
IMPORTS & DEPENDENCIES
=======================
This section imports all required libraries for the OCR pipeline.

Library purposes:
- datasets: Load HuggingFace datasets (Nepali text corpus)
- PIL/Image: Image creation, manipulation, and format handling
- numpy: Numerical arrays and mathematical operations
- matplotlib: Data visualization and plotting
- random: Randomization for data shuffling and augmentation
- re: Regular expressions for text extraction and cleaning
- os/glob: File system operations and path handling
- torch: PyTorch deep learning framework
- torch.nn: Neural network layers and models
- freetype: Low-level font rendering engine
- uharfbuzz: Text shaping for complex scripts (Devanagari)
- cv2: OpenCV for image transformations (perspective, distortion)
- Counter: Efficient counting of character frequencies
"""
from datasets import load_dataset
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import random
import re
import os
import glob
import torch
import torch.nn as nn
import freetype
import uharfbuzz as hb
import cv2
from collections import Counter

In [None]:
"""
LOAD NEPALI TEXT DATASET
==========================
Load the HuggingFace dataset containing real Nepali text samples.

Dataset Details:
- Source: "Sakonii/nepalitext-language-model-dataset"
- Contains: Real Nepali text in Devanagari script
- Purpose: Extract authentic vocabulary for OCR training
- Train split: Contains thousands of text samples

The dataset is publicly available and contains clean, real-world Nepali text
that we'll use to extract authentic words for synthetic image generation.
"""
# Load the NepaliText dataset
dataset = load_dataset("Sakonii/nepalitext-language-model-dataset")
train_texts = dataset["train"]["text"]
print(f"Loaded dataset with {len(train_texts)} training samples")

In [None]:
"""
CHARACTER ANALYSIS & FREQUENCY ANALYSIS
==========================================
Analyze and visualize the frequency distribution of characters in the dataset.

This section:
1. Defines a cleaning function to handle special characters
2. Counts character frequencies across all texts
3. Visualizes the top 50 most common characters

Technical Details:

DEVANAGARI UNICODE RANGE: U+0900 to U+097F (2304-2431 in decimal)
This range includes:
- Consonants (व्यञ्जन): क, ख, ग, घ, etc.
- Vowels (स्वर): अ, आ, इ, ई, उ, ऊ, etc.
- Diacritics (मात्रा): ा, ि, ी, ु, ू, ृ, etc.
- Nukta consonants: ड़, ढ़, etc.
- Digits and punctuation: ०-९, danda (।), etc.

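Characters Filtered Out:
- Newlines, tabs, carriage returns
- Zero-width and directional format characters U+200B-U+200F (8203-8207):
  zero-width space, ZWNJ, ZWJ, left-to-right mark, right-to-left mark
- Any other character for which str.isprintable() is False

Characters Kept:
- Everything in the Devanagari block
- Any other printable character (spaces, punctuation, Latin letters, etc.),
  not only printable ASCII

Why This Cleaning?
- OCR models perform better with consistent character sets
- Zero-width characters have no glyph of their own. Note that ZWJ/ZWNJ can
  still change how a conjunct is shaped, so removing them may alter the rendered form
- Reduces confusion between visually similar patterns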
"""
# Clean character function - removes unwanted characters
def clean_char(c):
    """
    Remove non-essential characters that don't contribute to visual OCR training.
    
    Args:
        c (str): Single character to clean
        
    Returns:
        str: Cleaned character (empty string if filtered out)
    """
    # Remove whitespace characters that don't render
    if c in ["\n", "\t", "\r"]:
        return ""
    
    # Remove zero-width space, ZWNJ, ZWJ and the LRM/RLM direction marks
    # Unicode range: U+200B to U+200F
    if ord(c) in [8203, 8204, 8205, 8206, 8207]:
        return ""
    
    # Keep all Devanagari characters (U+0900 to U+097F)
    if 2304 <= ord(c) <= 2431:
        return c
    
    # Keep other printable characters (spaces, punctuation, etc.)
    if c.isprintable():
        return c
    
    # Filter out everything else
    return ""

# Count character frequencies across entire dataset
char_freq = Counter()
for text in train_texts:
    if isinstance(text, str):
        # Clean and update frequency counter
        cleaned = "".join(clean_char(c) for c in text)
        char_freq.update(cleaned)

print(f"✓ Unique cleaned characters: {len(char_freq)}")

# Extract top 50 most frequent characters for visualization
top_50 = char_freq.most_common(50)
chars, freqs = zip(*top_50)

# Visualize character frequency distribution
plt.figure(figsize=(14, 6))
plt.bar(chars, freqs)
plt.xticks(rotation=90, fontsize=12)
plt.title("Top 50 Characters (Cleaned)")
plt.ylabel("Frequency")
plt.xlabel("Characters")
plt.tight_layout()
plt.show()

print("\nCharacter frequency analysis complete. Top characters are visualized above.")
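
As a complement to the bar chart, the frequency table can also be grouped by Unicode general category, which separates base letters (`Lo`) from combining marks (`Mn`, `Mc`), punctuation (`Po`) and digits (`Nd`). The sketch below is illustrative: `category_breakdown` is a hypothetical helper, and the toy counts stand in for the real `char_freq`.

```python
import unicodedata
from collections import Counter


def category_breakdown(freq):
    """Sum character counts per Unicode general category (e.g. Lo, Mn, Mc)."""
    totals = Counter()
    for ch, count in freq.items():
        totals[unicodedata.category(ch)] += count
    return totals


# Toy frequency table (made-up counts): KA, vowel sign AA, virama, danda
toy_freq = Counter({"\u0915": 5, "\u093E": 3, "\u094D": 2, "\u0964": 1})
print(category_breakdown(toy_freq))
```

Passing the notebook's `char_freq` instead of `toy_freq` would show how much of the corpus consists of combining marks rather than standalone letters.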

## Step 1: Extract 5000 Nepali Words

### Objective
Extract authentic Devanagari words from the dataset to use as training samples for synthetic image generation.

### Methodology
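- Use a regex pattern to identify contiguous runs of characters from the Devanagari block (this also matches the danda and the Devanagari digits)
- Filter by length (2-30 Unicode code points, so each matra or virama counts as one) to exclude single characters and unrealistic sequences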
- Create a vocabulary of unique words for synthetic training data

### Why Synthetic Data?
- **Controlled Augmentation**: We can vary fonts, sizes, rotations, and distortions
- **Unlimited Samples**: Generate as many training examples as needed
- **Authentic Text**: Uses real Nepali words, not randomly generated characters
- **Reproducibility**: Same vocabulary can be used for fair model comparisons
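
The following sketch shows what the block-wide pattern `[\u0900-\u097F]+` actually matches on a short sentence. Because the danda (U+0964) lies inside the Devanagari block, it stays attached to the preceding word. The second pattern, which skips U+0964 to U+0970, is an assumed stricter alternative shown for comparison; it is not what the notebook uses.

```python
import re

# Pattern used for extraction: the whole Devanagari block
DEVANAGARI_RUN = re.compile(r"[\u0900-\u097F]+")
# Assumed stricter variant: skips U+0964-U+0970
# (danda, double danda, digits, abbreviation sign)
DEVANAGARI_LETTERS = re.compile(r"[\u0900-\u0963\u0971-\u097F]+")

# Sample sentence ending in a danda, followed by ASCII digits
sample = "नेपाल राम्रो देश हो\u0964 2024"

broad = DEVANAGARI_RUN.findall(sample)
strict = DEVANAGARI_LETTERS.findall(sample)

print(broad)   # last item keeps the trailing danda
print(strict)  # last item has the danda removed
print([len(w) for w in strict])  # lengths are counted in code points
```

If labels with attached punctuation are unwanted, the stricter variant (or a post-filter) is one way to exclude them before sampling the vocabulary.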

In [None]:
"""
EXTRACT DEVANAGARI WORDS FROM DATASET
=======================================
Extract authentic Nepali words from the real text corpus.

Process:
1. Define regex pattern to match Devanagari character sequences
2. Iterate through all dataset texts
3. Extract matching word sequences
4. Filter by length constraints
5. Remove duplicates using a set
6. Randomly shuffle and select 5000 words for training

Regex Pattern: r'[\u0900-\u097F]+'
- \u0900-\u097F: Unicode range of the Devanagari block
  (letters and marks, plus the danda and the Devanagari digits)
- +: One or more consecutive Devanagari characters

Word Length Constraints (len() counts Unicode code points, not visual clusters):
- Minimum: 2 code points (avoid single-character words)
- Maximum: 30 code points (avoid unrealistic sequences)

Why Deduplication?
- Prevents bias towards frequently occurring words
- Ensures diverse vocabulary representation
- Reduces redundant synthetic image generation
"""

def extract_nepali_words(text):
    """
    Extract all Devanagari word sequences from a text string.
    
    Uses regex to find contiguous sequences of Devanagari characters.
    Filters by length to get reasonable word-like sequences.
    
    Args:
        text (str): Input text that may contain Devanagari and other characters
        
    Returns:
        list: Extracted Devanagari words meeting length criteria
    """
    if not isinstance(text, str):
        return []
    
    # Match one or more consecutive Devanagari characters (U+0900 to U+097F)
    matches = re.findall(r"[\u0900-\u097F]+", text)
    
    # Filter by length: keep only words between 2 and 30 characters
    return [w for w in matches if 2 <= len(w) <= 30]

# Extract all unique words from dataset
all_words = set()
print("Extracting Nepali words from dataset...")
print("=" * 70)

for i, text in enumerate(train_texts):
    words = extract_nepali_words(text)
    all_words.update(words)
    
    # Print progress at regular intervals
    if (i + 1) % 10000 == 0:
        print(f"  Processed {i + 1:6d} texts  |  Found {len(all_words):6d} unique words")

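# Prepare training word list
all_words = sorted(all_words)  # Fix the order first: set iteration order can vary between runs
random.seed(42)  # Fixed seed so the same 5000 words are selected on every run
random.shuffle(all_words)  # Randomize order to avoid order bias
training_words = all_words[:5000]  # Select first 5000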

print("=" * 70)
print(f"✓ Total unique words extracted: {len(all_words):6d}")
print(f"✓ Selected for training:         {len(training_words):6d} words")
print(f"✓ Sample words: {training_words[:10]}")
print("\nThese words will be used to generate synthetic training images.")

## Step 2: Synthetic Dataset Generator

### Overview
Create synthetic training images with proper Devanagari text rendering using HarfBuzz and various augmentation techniques.

### Key Features

#### Text Rendering with HarfBuzz
- **Complex Script Shaping**: Properly renders Devanagari ligatures and diacritics
- **Glyph Positioning**: Accurate placement of shaped glyphs on the canvas
- **Font Support**: Works with any TTF font that supports Devanagari

#### Data Augmentations
- **Blur**: Simulates camera focus variations (Gaussian blur)
- **Noise**: Adds Gaussian noise and salt-and-pepper noise
- **Rotation**: Random slight rotations (-7° to +7°)
- **Perspective Distortion**: Simulates viewing angle variations
- **Background Variation**: Random backgrounds to prevent overfitting to white/gray

#### Why These Augmentations?
- **Blur**: Handles out-of-focus camera captures
- **Noise**: Increases robustness to sensor noise and compression artifacts
- **Rotation**: Camera tilt and document orientation variations
- **Perspective**: Non-perpendicular document scanning
- **Background**: Prevents model from using background as discriminative feature

### Technical Implementation
Uses PIL for image manipulation, HarfBuzz for text shaping, OpenCV for perspective transforms.
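
One side effect of rotating with an expanding canvas is worth quantifying: for a wide, short word image, even a small angle noticeably increases the height. The sketch below computes the geometric bounding box of a rotated rectangle. `rotated_canvas_size` is a hypothetical helper, and the exact size produced by an imaging library may differ by a pixel because of rounding.

```python
import math


def rotated_canvas_size(width, height, angle_degrees):
    """Bounding box (w, h) of a width x height rectangle rotated by the given angle."""
    theta = math.radians(angle_degrees)
    cos_t, sin_t = abs(math.cos(theta)), abs(math.sin(theta))
    new_w = width * cos_t + height * sin_t
    new_h = width * sin_t + height * cos_t
    return new_w, new_h


# A 300 x 96 word image rotated by the maximum angle of 7 degrees
w, h = rotated_canvas_size(300, 96, 7)
print(f"{w:.1f} x {h:.1f}")
```

For this example the height grows by more than a third, so the text occupies a smaller share of the image once it is later resized to a fixed input height.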

In [None]:
"""
SYNTHETIC OCR DATASET GENERATOR CLASS
======================================
Generate synthetic training images with proper Devanagari rendering and augmentations.

Architecture:
1. Load TTF fonts from fonts directory
2. For each word:
   a. Render text using HarfBuzz shaping engine
   b. Apply random augmentations (blur, noise, rotation, distortion)
   c. Save image and corresponding label
   d. Track progress

Key Method: render_text_image()
- Uses HarfBuzz for proper Devanagari glyph shaping
- Handles complex ligatures and diacritics
- Applies FreeType for glyph rendering
- Composes glyphs on canvas with proper positioning

Augmentation Chain:
1. Gaussian blur (50% probability)
2. Rotation (integer angle from -7° to +7°, always applied; 0° leaves the image unchanged)
3. Perspective distortion (always applied)
4. Noise: Gaussian (50% probability) and salt-and-pepper (50% probability), drawn independently
"""

class SyntheticHarfBuzzOCRDatasetGenerator:
    """
    Generate synthetic OCR training dataset with proper Devanagari text rendering.
    
    This generator uses:
    - HarfBuzz for complex script shaping
    - FreeType for glyph rendering
    - PIL for image composition
    - OpenCV for geometric transformations
    - NumPy for noise generation
    """

    def __init__(
        self,
        strings,
        fonts_dir="fonts",
        output_dir="data/word_images",
        font_size_range=(40, 56),
        random_blur=True,
        random_noise=True,
        random_rotate=True,
        random_distortion=True,
        background_mode="random",
        max_image_size=1024
    ):
        """
        Initialize the synthetic dataset generator.
        
        Args:
            strings (list): List of text strings to render
            fonts_dir (str): Directory containing TTF font files
            output_dir (str): Directory to save generated images
            font_size_range (tuple): Min and max font sizes in points
            random_blur (bool): Apply Gaussian blur augmentation
            random_noise (bool): Apply Gaussian and salt-pepper noise
            random_rotate (bool): Apply random rotation (-7 to 7 degrees)
            random_distortion (bool): Apply perspective distortion
            background_mode (str): "white", "lightgray", or "random"
            max_image_size (int): Maximum image dimension to prevent memory issues
        """
        self.strings = strings
        
        # Load all available TTF fonts
        self.fonts = glob.glob(os.path.join(fonts_dir, "**/*.ttf"), recursive=True)
        if not self.fonts:
            raise ValueError(f"No fonts found in {fonts_dir}. Please download Devanagari fonts.")

        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

        # Augmentation parameters
        self.font_size_range = font_size_range
        self.random_blur = random_blur
        self.random_noise = random_noise
        self.random_rotate = random_rotate
        self.random_distortion = random_distortion
        self.background_mode = background_mode
        self.MAX_SIZE = max_image_size

    def _clamp_image_size(self, img):
        """
        Ensure image doesn't exceed maximum size.
        
        This prevents memory issues with very large generated images.
        Uses high-quality LANCZOS resampling for downscaling.
        
        Args:
            img (PIL.Image): Input image
            
        Returns:
            PIL.Image: Image with clamped dimensions
        """
        w, h = img.size
        if w > self.MAX_SIZE or h > self.MAX_SIZE:
            img.thumbnail((self.MAX_SIZE, self.MAX_SIZE), Image.LANCZOS)
        return img

    def generate_dataset(self):
        """
        Generate the complete synthetic dataset.
        
        For each input string:
        1. Render it as an image using render_text_image()
        2. Save image as PNG (indexed by 5-digit counter)
        3. Save label as text file with corresponding text
        4. Print progress every 500 images
        
        Output Structure:
            output_dir/
            ├── 00001.png
            ├── 00001.txt
            ├── 00002.png
            ├── 00002.txt
            └── ...
        """
        for idx, text in enumerate(self.strings, start=1):
            img = self.render_text_image(text)
            image_path = os.path.join(self.output_dir, f"{idx:05d}.png")
            label_path = os.path.join(self.output_dir, f"{idx:05d}.txt")
            
            img.save(image_path)
            with open(label_path, "w", encoding="utf-8") as f:
                f.write(text)
            
            if idx % 500 == 0:
                print(f"  [{idx:5d}/{len(self.strings)}] Generated images")

    def render_text_image(self, text, padding=20):
        """
        Render a single text string as an image with augmentations.
        
        Process:
        1. Select random font and size
        2. Shape text using HarfBuzz (handles complex Devanagari)
        3. Calculate required image dimensions
        4. Create background image
        5. Render glyphs using FreeType
        6. Apply augmentations (each only if enabled in the constructor):
           - Blur (50% probability)
           - Rotation (always)
           - Perspective distortion (always)
           - Noise (Gaussian and salt-and-pepper, 50% probability each)
        
        Args:
            text (str): Devanagari text to render
            padding (int): Pixel padding around text
            
        Returns:
            PIL.Image: Rendered and augmented image
        """
        # Random font and size selection
        font_path = random.choice(self.fonts)
        font_size = random.randint(*self.font_size_range)
        
        # Initialize FreeType face
        face = freetype.Face(font_path)
        face.set_char_size(font_size * 64)  # 26.6 fixed point (1/64 pt); 1 pt = 1 px at the default 72 dpi

        # === HARFBUZZ TEXT SHAPING ===
        # HarfBuzz properly shapes complex Devanagari text, handling:
        # - Ligatures (e.g., क्ष, त्र)
        # - Diacritics (e.g., ा, ि)
        # - Conjuncts (combined consonants)
        
        hb_blob = hb.Blob.from_file_path(font_path)
        hb_face = hb.Face(hb_blob, 0)
        hb_font = hb.Font(hb_face)
        # Match FreeType's 26.6 scale so HarfBuzz advances agree with the rendered glyph size
        hb_font.scale = (font_size * 64, font_size * 64)

        buf = hb.Buffer()
        buf.add_str(text)
        buf.guess_segment_properties()  # Auto-detect script and language
        hb.shape(hb_font, buf)

        infos = buf.glyph_infos  # Glyph indices and properties
        positions = buf.glyph_positions  # Positioning info (advance, offset)

        # Calculate required image dimensions
        width = sum(pos.x_advance for pos in positions) // 64 + 2*padding
        height = font_size + 2*padding

        # === CREATE BACKGROUND ===
        if self.background_mode == "white":
            img = Image.new("RGB", (width, height), "white")
        elif self.background_mode == "lightgray":
            img = Image.new("RGB", (width, height), "lightgray")
        else:
            # Random per-pixel background, intensity 200-255 (randint's upper bound is exclusive)
            arr = np.random.randint(200, 256, (height, width, 3), dtype=np.uint8)
            img = Image.fromarray(arr)

        # Starting position for glyph rendering
        x, y = padding, padding + font_size

        # === RENDER GLYPHS ===
        for info, pos in zip(infos, positions):
            glyph_index = info.codepoint
            face.load_glyph(glyph_index, freetype.FT_LOAD_RENDER)
            bitmap = face.glyph.bitmap
            top = face.glyph.bitmap_top
            left = face.glyph.bitmap_left

            if bitmap.width > 0 and bitmap.rows > 0:
                # Convert FreeType bitmap to PIL image
                glyph_img = Image.frombytes(
                    "L", 
                    (bitmap.width, bitmap.rows), 
                    bytes(bitmap.buffer)
                )
                # Solid black image; glyph_img (coverage values) is used as the paste mask
                colored_glyph = Image.new("RGB", glyph_img.size, "black")
                # Apply HarfBuzz offsets (needed for mark placement) on top of the pen position
                gx = x + pos.x_offset / 64 + left
                gy = y - pos.y_offset / 64 - top
                img.paste(colored_glyph, (int(gx), int(gy)), glyph_img)

            # Move cursor based on HarfBuzz positioning
            x += pos.x_advance / 64  # 26.6 fixed point to pixels
            y -= pos.y_advance / 64

        img = self._clamp_image_size(img)

        # === APPLY AUGMENTATIONS ===
        
        # 1. Gaussian Blur (50% probability)
        if self.random_blur and random.random() < 0.5:
            from PIL import ImageFilter
            img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(0.5, 1.5)))

        # 2. Random Rotation (-7 to +7 degrees)
        if self.random_rotate:
            angle = random.randint(-7, 7)
            img = img.rotate(angle, expand=True, fillcolor="white")
            img = self._clamp_image_size(img)

        # 3. Perspective Distortion
        if self.random_distortion:
            img = self.perspective_distortion(img)

        # 4. Noise Addition
        if self.random_noise:
            img = self.add_noise(img)

        return img

    def perspective_distortion(self, img):
        """
        Apply random perspective distortion to simulate camera angle variations.
        
        This transformation is useful for:
        - Simulating non-perpendicular document scanning
        - Handling camera tilt in real OCR scenarios
        - Increasing model robustness to viewing angles
        
        Process:
        1. Define source corners as image boundaries
        2. Randomly perturb destination corners (up to ±10% of the shorter image side)
        3. Compute perspective transformation matrix
        4. Apply transformation using OpenCV warpPerspective
        
        Args:
            img (PIL.Image): Input image
            
        Returns:
            PIL.Image: Perspective-distorted image
        """
        img = self._clamp_image_size(img)
        w, h = img.size
        arr = np.array(img)
        
        # Maximum distortion amount (±10% of the shorter image side)
        shift = min(w, h) * 0.1

        # Source points (image corners)
        pts1 = np.float32([[0, 0], [w, 0], [0, h], [w, h]])
        
        # Destination points (randomly perturbed corners)
        pts2 = np.float32([
            [random.uniform(-shift, shift), random.uniform(-shift, shift)],
            [w + random.uniform(-shift, shift), random.uniform(-shift, shift)],
            [random.uniform(-shift, shift), h + random.uniform(-shift, shift)],
            [w + random.uniform(-shift, shift), h + random.uniform(-shift, shift)],
        ])
        
        # Compute perspective transformation matrix
        matrix = cv2.getPerspectiveTransform(pts1, pts2)
        
        # Apply perspective transformation (white border for out-of-bounds regions)
        warped = cv2.warpPerspective(
            arr, 
            matrix, 
            (w, h), 
            borderMode=cv2.BORDER_CONSTANT, 
            borderValue=(255, 255, 255)
        )
        return Image.fromarray(warped)

    def add_noise(self, img):
        """
        Add realistic noise to the image.
        
        This includes (each applied independently with 50% probability):
        1. Gaussian noise (zero-mean, σ=10)
           - Simulates sensor noise and compression artifacts
        
        2. Salt-and-pepper noise (about 2% of pixels)
           - 1% turned to white (salt)
           - 1% turned to black (pepper)
           - Simulates sensor spikes and transmission errors
        
        Args:
            img (PIL.Image): Input image
            
        Returns:
            PIL.Image: Noisy image
        """
        arr = np.array(img).astype(np.float32)
        
        # Gaussian noise (50% probability)
        if random.random() < 0.5:
            # Add Gaussian noise with std dev = 10
            arr += np.random.normal(0, 10, arr.shape)
        
        # Salt-and-pepper noise (50% probability)
        if random.random() < 0.5:
            amount = 0.02  # 2% of pixels affected
            h, w = arr.shape[:2]
            num_salt = int(h * w * amount * 0.5)
            num_pepper = int(h * w * amount * 0.5)

            # Add salt (white) noise: index whole pixels so all channels change together
            ys = np.random.randint(0, h, num_salt)  # upper bound is exclusive
            xs = np.random.randint(0, w, num_salt)
            arr[ys, xs] = 255

            # Add pepper (black) noise
            ys = np.random.randint(0, h, num_pepper)
            xs = np.random.randint(0, w, num_pepper)
            arr[ys, xs] = 0
        
        # Clamp values to valid range [0, 255]
        arr = np.clip(arr, 0, 255)
        return Image.fromarray(arr.astype(np.uint8))
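
Shaped glyph positions carry two kinds of values: advances, which move the pen, and offsets, which shift a single glyph (typically a mark) relative to the pen without moving it. The sketch below works through that arithmetic in isolation. `GlyphPos` is a hypothetical stand-in that mirrors the field names of a shaper's position records, the values are made up, and the y axis is assumed to point down as in image coordinates.

```python
from collections import namedtuple

# Hypothetical stand-in for one shaped-glyph position record (26.6 fixed-point units)
GlyphPos = namedtuple("GlyphPos", "x_advance y_advance x_offset y_offset")


def glyph_origins(positions, start_x, baseline_y, units_per_pixel=64):
    """Return the drawing origin (x, y) in pixels for each glyph."""
    origins = []
    pen_x, pen_y = float(start_x), float(baseline_y)
    for pos in positions:
        # Offsets shift this glyph only; font y points up, image y points down
        origin_x = pen_x + pos.x_offset / units_per_pixel
        origin_y = pen_y - pos.y_offset / units_per_pixel
        origins.append((origin_x, origin_y))
        # Advances move the pen for the following glyphs
        pen_x += pos.x_advance / units_per_pixel
        pen_y -= pos.y_advance / units_per_pixel
    return origins


# A base glyph (10 px advance) followed by a zero-advance mark shifted left and up
toy = [GlyphPos(640, 0, 0, 0), GlyphPos(0, 0, -320, 128)]
print(glyph_origins(toy, start_x=20, baseline_y=60))
```

In this toy case the mark is drawn 5 px to the left of the pen and 2 px above the baseline, while the pen itself stays at x = 30 for whatever glyph comes next.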

## Step 3: Generate Dataset

### Execution
This step runs the dataset generator to create synthetic training images from the 5000 Nepali words extracted in Step 1.

### Process
- Initialize generator with augmentation parameters
- Generate images one at a time (the generator loop is sequential)
- Save each image with corresponding label file
- Total output: 5000 images + 5000 label files

### Expected Duration
Depends on system performance, typically 5-30 minutes for 5000 images.

In [None]:
"""
GENERATE SYNTHETIC TRAINING DATASET
====================================
Execute the synthetic dataset generator to create training images.

Configuration:
- Font size range: 40-56 points (provides variation)
- All augmentations enabled:
  * Blur: Handles out-of-focus variations
  * Noise: Adds robustness to sensor noise
  * Rotation: Handles document tilt
  * Distortion: Handles perspective variations
- Background: Random (prevents background-based overfitting)
- Image size limit: 1024 pixels (memory management)

Output Structure:
    data/word_images/
    ├── 00001.png  (synthetic image)
    ├── 00001.txt  (label: "काठमाडौ" or similar)
    ├── 00002.png
    ├── 00002.txt
    ├── ...
    └── 05000.png / 05000.txt

Total Files Created: 10,000 (5,000 image-label pairs)
Average Image Size: ~50-150 KB per PNG

This dataset is ready for training the CRNN model.
"""
print("=" * 70)
print("GENERATING SYNTHETIC DATASET (5000 WORD SAMPLES)")
print("=" * 70)

# Initialize generator with full augmentation pipeline
generator = SyntheticHarfBuzzOCRDatasetGenerator(
    strings=training_words,
    fonts_dir="fonts",
    output_dir="data/word_images",
    font_size_range=(40, 56),
    random_blur=True,
    random_noise=True,
    random_rotate=True,
    random_distortion=True,
    background_mode="random",
    max_image_size=1024
)

# Generate all images and labels
generator.generate_dataset()

print("\n✓ DATASET GENERATION COMPLETE!")
print("=" * 70)

# Verification: Count generated files
output_dir = "data/word_images"
image_files = [f for f in os.listdir(output_dir) if f.endswith(".png")]
label_files = [f for f in os.listdir(output_dir) if f.endswith(".txt")]

print(f"✓ Generated {len(image_files)} images")
print(f"✓ Generated {len(label_files)} labels")
print(f"\nDataset location: {output_dir}")
print(f"Ready for training!")
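
Counting files confirms the totals but not that every image has a matching label. A slightly stricter check, sketched below, compares file stems in both directions. `find_unpaired` is a hypothetical helper, and the directory path is the one assumed throughout this notebook.

```python
import os


def find_unpaired(output_dir):
    """Return (images without labels, labels without images) as sorted stem lists."""
    names = os.listdir(output_dir)
    png_stems = {os.path.splitext(n)[0] for n in names if n.endswith(".png")}
    txt_stems = {os.path.splitext(n)[0] for n in names if n.endswith(".txt")}
    return sorted(png_stems - txt_stems), sorted(txt_stems - png_stems)


# Only run the check if the dataset directory exists
if os.path.isdir("data/word_images"):
    missing_labels, missing_images = find_unpaired("data/word_images")
    print(f"Images without labels: {len(missing_labels)}")
    print(f"Labels without images: {len(missing_images)}")
```

Both lists should be empty after a complete run; a non-empty list usually points to an interrupted generation.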

## Step 4: Create Charset

### Purpose
Extract all unique characters from the training vocabulary and create a charset file that will be used by the model.

### Why Charset?
- **Class Indexing**: Each character gets a unique index from 1 to len(charset); index 0 is reserved for the CTC blank
- **Output Layer**: The model's output layer must have size num_classes = len(charset) + 1 (the extra class is the CTC blank)
- **Decoding**: Used to convert predicted indices back to characters

### Charset File Format
- Single line containing all unique characters in order
- Encoding: UTF-8 (important for Devanagari)
- Used during training and inference for character-to-index mapping
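
The character-to-index mapping described above can be sketched as follows, assuming index 0 is reserved for the CTC blank and characters occupy indices 1 to len(charset). The helper names `build_codec`, `encode` and `decode` are illustrative and are not defined elsewhere in this notebook.

```python
def build_codec(charset):
    """Map characters to indices 1..N; index 0 stays free for the CTC blank."""
    char_to_idx = {ch: i + 1 for i, ch in enumerate(charset)}
    idx_to_char = {i: ch for ch, i in char_to_idx.items()}
    return char_to_idx, idx_to_char


def encode(text, char_to_idx):
    """Convert a label string to a list of class indices."""
    return [char_to_idx[ch] for ch in text]


def decode(indices, idx_to_char, blank=0):
    """Convert class indices back to text, skipping the blank index."""
    return "".join(idx_to_char[i] for i in indices if i != blank)


# Toy charset built the same way as in this step: unique characters, sorted
toy_charset = sorted(set("कमल"))
char_to_idx, idx_to_char = build_codec(toy_charset)

encoded = encode("कलम", char_to_idx)
print(encoded)
print(decode(encoded, idx_to_char))
```

Note that `encode` raises `KeyError` for a character missing from the charset, which is a useful early warning when labels and charset come from different vocabularies.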

In [None]:
"""
EXTRACT AND SAVE CHARACTER SET
================================
Create a charset from all unique characters in the training vocabulary.

Process:
1. Iterate through all training words
2. Collect unique characters
3. Sort by Unicode code point for consistent ordering
4. Save to charset.txt (one line, all characters)

Charset Details:
- num_classes = len(charset) + 1
  * Additional class for CTC blank token (index 0)
  * Character indices: 1 to len(charset)
  
Why +1 for CTC?
- CTC (Connectionist Temporal Classification) requires a blank token
- This token represents the absence of character (for variable-length alignment)
- Reserved as index 0 in the output layer

Character Ordering:
- Sorted by Unicode code point (Python's default string ordering) for consistency
- Enables reproducible index-to-character mapping
- Simplifies debugging and cross-model comparison
"""

# Extract unique characters from all training words
charset = set()
for word in training_words:
    charset.update(word)  # Add all characters in the word

# Sort by code point for consistent, reproducible ordering
charset = sorted(charset)

# Save charset to file
with open("charset.txt", "w", encoding="utf-8") as f:
    f.write("".join(charset))

print("=" * 70)
print("CHARACTER SET EXTRACTION")
print("=" * 70)
print(f"✓ Charset: {len(charset)} unique characters")
print(f"✓ num_classes = {len(charset) + 1} (including CTC blank token at index 0)")
print(f"\nCharset file saved to: charset.txt")
print(f"\nCharacter breakdown:")
print(f"  - Index 0: CTC blank token (reserved for temporal alignment)")
print(f"  - Index 1-{len(charset)}: Character indices")
print(f"\nExample characters: {charset[:10]}")
print("=" * 70)

## Step 5: Model Architecture

### Overview
Define the CRNN (Convolutional Recurrent Neural Network) model architecture for Devanagari OCR.

### Architecture Stack
1. **CNN Feature Extractor** (CRNNFeatureExtractor)
   - Extracts visual features from image
   - Output: Sequence of feature vectors

2. **Bidirectional LSTM** (BidirectionalLSTM) × 2
   - Captures sequential dependencies
   - First layer: hidden_size → hidden_size
   - Second layer: hidden_size → num_classes

3. **CTC Loss** (Connectionist Temporal Classification)
   - Handles variable-length character sequences
   - No need for explicit character alignment

### Why CRNN+LSTM?
- **CNNs**: Excellent for spatial feature extraction
- **RNNs**: Model sequential/temporal dependencies
- **Bidirectional**: Uses context from both directions
- **LSTM**: Captures long-range dependencies better than vanilla RNN
- **CTC**: Aligns variable-length sequences without explicit alignment

### Model Flow
Image → CNN (spatial features) → BiLSTM × 2 (sequence modeling) → Output (character probabilities)
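
The number of time steps the BiLSTM and CTC see is fixed by the CNN's pooling and by its final unpadded convolution. The sketch below traces the spatial size through the layer stack defined in this step's code cell, using the standard output-size formula for convolution and pooling. `conv_out` and `crnn_feature_size` are hypothetical helpers, and the 32 x 128 input size is an assumed example rather than a value taken from the notebook.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a convolution or pooling layer along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def crnn_feature_size(height, width):
    """Spatial size (H', W') of the CNN output for a given input size."""
    h, w = height, width
    # Blocks 1-2: 3x3 conv with padding 1 keeps the size; 2x2 pooling halves H and W
    for _ in range(2):
        h, w = conv_out(h, 2, stride=2), conv_out(w, 2, stride=2)
    # Blocks 4 and 6: (2, 1) pooling halves H only; blocks 3 and 5 do not pool
    for _ in range(2):
        h = conv_out(h, 2, stride=2)
    # Block 7: 2x2 conv without padding trims one row and one column
    h, w = conv_out(h, 2), conv_out(w, 2)
    return h, w


h_out, w_out = crnn_feature_size(32, 128)
print(f"H'={h_out}, sequence length T={w_out}")
```

Since CTC needs at least as many time steps as label characters (more when a character repeats), this calculation is a quick way to check that the chosen input width is long enough for the longest words.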

In [None]:
"""
NEURAL NETWORK MODEL ARCHITECTURE
===================================
Define CRNN + BiLSTM + CTC model for Devanagari OCR.

This implementation follows the CRNN architecture from Shi, Bai, and Yao:
"An End-to-End Trainable Neural Network for Image-based Sequence Recognition
and Its Application to Scene Text Recognition"

Architecture Layers:
├── CNN Feature Extractor
│   ├── Conv2d(1, 64, 3×3)    - Initial feature extraction
│   ├── Conv2d(64, 128, 3×3)  - Mid-level features
│   ├── Conv2d(128, 256, 3×3) - Higher-level patterns
│   ├── Conv2d(256, 256, 3×3) - Deeper patterns
│   ├── Conv2d(256, 512, 3×3) - Complex feature combinations
│   ├── Conv2d(512, 512, 3×3) - Final CNN layer
│   └── Conv2d(512, 512, 2×2) - Dimension reduction
│
├── BiLSTM Layer 1
│   └── 512 → 256 hidden units (bidirectional)
│
├── BiLSTM Layer 2
│   └── 256 → num_classes output
│
└── CTC Loss Function
    └── For alignment-free sequence learning
"""

class CRNNFeatureExtractor(nn.Module):
    """
    CNN backbone for feature extraction.
    
    This convolutional encoder progressively extracts features:
    - Early layers: Edge and small pattern detection
    - Middle layers: Local features and textures
    - Deep layers: Semantic features and character patterns
    
    Architecture Details:
    - Kernel size: 3×3 (good for Devanagari character features)
    - Pooling: 2×2 (downsamples spatial dimensions)
    - BatchNorm: Stabilizes training and enables higher learning rates
    - Final pooling: (2,1) in last stages (preserve horizontal resolution)
    
    Why (2,1) pooling?
    - Devanagari characters have distinct vertical (height) patterns
    - Horizontal resolution is critical for sequence learning
    - Reduces height but preserves width for LSTM processing
    """
    
    def __init__(self, img_channels=1):
        """
        Initialize CNN feature extractor.
        
        Args:
            img_channels (int): Input channels (1 for grayscale, 3 for RGB)
        """
        super().__init__()
        self.cnn = nn.Sequential(
            # Block 1: 1 → 64 channels
            nn.Conv2d(img_channels, 64, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # H/2, W/2
            
            # Block 2: 64 → 128 channels
            nn.Conv2d(64, 128, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # H/4, W/4
            
            # Block 3: 128 → 256 channels
            nn.Conv2d(128, 256, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(256),  # Normalize before next layer
            
            # Block 4: 256 → 256 channels (deeper feature extraction)
            nn.Conv2d(256, 256, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d((2, 1), (2, 1)),  # H/8, W/4 (height halved, width preserved)
            
            # Block 5: 256 → 512 channels
            nn.Conv2d(256, 512, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(512),
            
            # Block 6: 512 → 512 channels (final deep layer)
            nn.Conv2d(512, 512, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d((2, 1), (2, 1)),  # H/2, W same
            
            # Block 7: 2×2 conv without padding (height 2 → 1 for 32-pixel-high inputs; width shrinks by 1)
            nn.Conv2d(512, 512, 2, stride=1, padding=0),
            nn.ReLU()
        )

    def forward(self, x):
        """
        Forward pass through CNN.
        
        Input shape: (batch, channels, height, width)
        Output shape: (width, batch, channels)
        
        The output is reshaped for LSTM:
        - Collapses height dimension (spatial averaging)
        - Produces sequence along width dimension
        - Width position becomes temporal dimension for LSTM
        
        Args:
            x (torch.Tensor): Input images (B, C, H, W)
            
        Returns:
            torch.Tensor: Sequence of feature vectors (W, B, C)
        """
        conv_output = self.cnn(x)  # (B, 512, H', W')
        b, c, h, w = conv_output.size()
        
        # Average across height dimension (treat height as spatial, not temporal)
        conv_output = conv_output.mean(2)  # (B, 512, W')
        
        # Rearrange for LSTM: (W', B, 512)
        # Width becomes temporal dimension (character sequence)
        return conv_output.permute(2, 0, 1)


class BidirectionalLSTM(nn.Module):
    """
    Bidirectional LSTM layer with output projection.
    
    Purpose: Capture sequential dependencies in character sequences.
    
    Why Bidirectional?
    - Forward LSTM: Uses context from left (previous characters)
    - Backward LSTM: Uses context from right (future characters)
    - Combined: Richer representation using full context
    
    Process:
    1. LSTM processes sequence (forward + backward)
    2. Concatenate forward and backward hidden states
    3. Project to output dimension using linear layer
    
    Hidden State Concatenation:
    - LSTM output shape: (seq_len, batch, 2*hidden_size)
    - Linear layer maps: 2*hidden_size → output_size
    """
    
    def __init__(self, input_size, hidden_size, output_size):
        """
        Initialize BiLSTM layer.
        
        Args:
            input_size (int): Dimension of input features
            hidden_size (int): Number of LSTM hidden units (each direction)
            output_size (int): Output dimension after projection
        """
        super().__init__()
        # Bidirectional LSTM: outputs 2*hidden_size (forward + backward concatenated)
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers=1, bidirectional=True)
        
        # Project BiLSTM output to desired dimension
        # Input: 2*hidden_size (bidirectional concatenation)
        # Output: output_size
        self.embedding = nn.Linear(hidden_size * 2, output_size)

    def forward(self, x):
        """
        Forward pass through BiLSTM.
        
        Args:
            x (torch.Tensor): Input sequence (seq_len, batch, input_size)
            
        Returns:
            torch.Tensor: Projected output (seq_len, batch, output_size)
        """
        # BiLSTM forward
        recurrent, _ = self.rnn(x)  # (seq_len, batch, 2*hidden_size)
        
        # Project to output dimension
        return self.embedding(recurrent)  # (seq_len, batch, output_size)


class OCRModel(nn.Module):
    """
    Complete OCR model: CNN + BiLSTM + CTC.
    
    Architecture Flow:
    Input Image
        ↓
    CNN Feature Extractor (CRNNFeatureExtractor)
        ↓ Feature Sequence (width, batch, 512)
    BiLSTM Layer 1 (512 → 256)
        ↓ (width, batch, 256)
    BiLSTM Layer 2 (256 → num_classes)
        ↓ (width, batch, num_classes)
    CTC Loss → Character Predictions
    
    Why This Architecture?
    - CNN: Excellent for image-to-feature conversion
    - LSTM: Models character sequence dependencies
    - CTC: Handles alignment automatically (no explicit character positions needed)
    
    Training:
    - Forward pass outputs logits (not normalized)
    - Converted to log probabilities for CTC loss
    - CTC loss aligns predictions with ground truth
    """
    
    def __init__(self, num_classes, img_channels=1, hidden_size=256):
        """
        Initialize OCR model.
        
        Args:
            num_classes (int): Number of character classes (including CTC blank)
            img_channels (int): Input image channels (1 for grayscale)
            hidden_size (int): LSTM hidden layer dimension
        """
        super().__init__()
        self.cnn = CRNNFeatureExtractor(img_channels)
        
        # Stacked BiLSTM layers for deeper sequence modeling
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, hidden_size, hidden_size),  # 512 → 256
            BidirectionalLSTM(hidden_size, hidden_size, num_classes)  # 256 → num_classes
        )
        
        # CTC Loss for alignment-free training
        # blank=0 corresponds to CTC blank token
        # zero_infinity=True: zeroes infinite losses (and their gradients), which occur
        # when an input sequence is too short to be aligned with its target
        self.ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)

    def forward(self, x):
        """
        Forward pass: image to character predictions.
        
        Args:
            x (torch.Tensor): Input images (batch, channels, height, width)
            
        Returns:
            torch.Tensor: Character logits (seq_len, batch, num_classes)
        """
        features = self.cnn(x)  # (width, batch, 512)
        return self.rnn(features)  # (width, batch, num_classes)

    def compute_ctc_loss(self, preds, targets, pred_lengths, target_lengths):
        """
        Compute CTC loss for training.
        
        CTC (Connectionist Temporal Classification) Loss:
        - Handles variable-length sequences
        - Sums over all valid alignments between predictions and targets
        - No explicit character position labels needed
        
        Args:
            preds (torch.Tensor): Model predictions (seq_len, batch, num_classes)
            targets (torch.Tensor): Ground truth character indices (batch, max_target_len)
            pred_lengths (torch.Tensor): Actual prediction lengths (batch,)
            target_lengths (torch.Tensor): Actual target lengths (batch,)
            
        Returns:
            torch.Tensor: CTC loss value (scalar)
        """
        # Convert logits to log-softmax (CTC expects log probabilities)
        preds_log = preds.log_softmax(2)
        
        # Compute CTC loss with automatic alignment
        return self.ctc_loss(preds_log, targets, pred_lengths, target_lengths)

print("✓ Model architecture classes defined successfully")
print("\nModel Summary:")
print("├── CRNNFeatureExtractor: CNN backbone (1-channel → 512-channel features)")
print("├── BidirectionalLSTM: Bidirectional LSTM with output projection")
print("└── OCRModel: Complete pipeline with CTC loss")
print("\nModel is ready for training!")
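
The shape comments in `CRNNFeatureExtractor` can be checked with plain arithmetic. The sketch below is not part of the model: it assumes the 32×256 input size used in the configuration later in this notebook and applies the usual output-size formula `floor((n + 2p - k) / s) + 1` to each convolution and pooling layer. `out_size` and `trace_cnn_shape` are hypothetical helper names.

```python
def out_size(n, kernel, stride, padding=0):
    """Output length of a conv/pool layer along one axis."""
    return (n + 2 * padding - kernel) // stride + 1


def trace_cnn_shape(height, width):
    """Follow (H, W) through the layers of CRNNFeatureExtractor."""
    # (kernel_h, kernel_w), (stride_h, stride_w), (pad_h, pad_w) per layer
    layers = [
        ((3, 3), (1, 1), (1, 1)),  # Block 1 conv
        ((2, 2), (2, 2), (0, 0)),  # Block 1 pool
        ((3, 3), (1, 1), (1, 1)),  # Block 2 conv
        ((2, 2), (2, 2), (0, 0)),  # Block 2 pool
        ((3, 3), (1, 1), (1, 1)),  # Block 3 conv
        ((3, 3), (1, 1), (1, 1)),  # Block 4 conv
        ((2, 1), (2, 1), (0, 0)),  # Block 4 pool (height only)
        ((3, 3), (1, 1), (1, 1)),  # Block 5 conv
        ((3, 3), (1, 1), (1, 1)),  # Block 6 conv
        ((2, 1), (2, 1), (0, 0)),  # Block 6 pool (height only)
        ((2, 2), (1, 1), (0, 0)),  # Block 7 conv (no padding)
    ]
    for (kh, kw), (sh, sw), (ph, pw) in layers:
        height = out_size(height, kh, sh, ph)
        width = out_size(width, kw, sw, pw)
    return height, width


print(trace_cnn_shape(32, 256))  # (1, 63)
```

Under these assumptions the LSTM receives 63 time steps per image, which is also the upper bound on the label length that CTC can emit for one image (lower when a word repeats the same character back to back, since CTC needs a blank between repeats).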

## Step 6: Create Configuration & Training Setup

### Configuration Parameters
Create a YAML configuration file with all training hyperparameters.

### Key Parameters
- **Model**: num_classes, hidden_size, image dimensions
- **Training**: batch_size, epochs, learning_rate, optimizer settings
- **Data**: paths to fonts, images, charset
- **Augmentation**: font size range, distortion settings
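
To make the scheduler settings concrete, here is a minimal sketch of a step schedule using the values from the configuration cell that follows (`learning_rate` 0.001, `scheduler_step` 15, `scheduler_gamma` 0.5). It uses the closed form that `torch.optim.lr_scheduler.StepLR` follows; whether the training script uses that scheduler is an assumption, and `step_lr` is a hypothetical helper name.

```python
def step_lr(epoch, base_lr=0.001, step=15, gamma=0.5):
    """Learning rate at a given 0-based epoch under a step schedule."""
    return base_lr * gamma ** (epoch // step)


# Learning rate at the first epoch of each stage in a 50-epoch run
for epoch in (0, 15, 30, 45):
    print(epoch, step_lr(epoch))
```

Over 50 epochs the rate is halved three times, ending at one eighth of its starting value.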

### Next Steps
After this notebook:
1. Run `python scripts/train_word_ocr.py` to start training
2. Monitor validation metrics
3. Save best model checkpoint
4. Evaluate on test set
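
Steps 2 and 3 can be combined in a small helper that tracks the best validation loss and signals when to stop. This is a sketch only: the `EarlyStopping` class, its `patience` and `min_delta` parameters, and the loss values in the example are all hypothetical and not taken from the training script.

```python
class EarlyStopping:
    """Track the best validation loss and signal when to stop."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience    # epochs to wait without improvement
        self.min_delta = min_delta  # smallest change that counts as improvement
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Return (is_best, should_stop) for this epoch's validation loss."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
            return True, False
        self.bad_epochs += 1
        return False, self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=2)
for loss in [1.0, 0.8, 0.85, 0.9]:
    is_best, should_stop = stopper.step(loss)
    print(loss, is_best, should_stop)
```

A training loop would save a checkpoint whenever `is_best` is true and break out of the epoch loop when `should_stop` is true.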

In [None]:
"""
CREATE TRAINING CONFIGURATION FILE
====================================
Generate YAML configuration with all hyperparameters and paths.

Configuration Components:

1. MODEL ARCHITECTURE
   - num_classes: Character classes + CTC blank (important: +1)
   - hidden_size: LSTM hidden units (256 is standard, can try 128/512)
   - num_channels: 1 (grayscale), 3 (RGB)

2. IMAGE DIMENSIONS
   - img_height: 32 pixels (standard for OCR)
   - img_width: 256 pixels (the CNN reduces this to 63 time steps for the LSTM)
   - Aspect ratio: 8:1 (wide, short images suit single words or lines of text)

3. TRAINING HYPERPARAMETERS
   - batch_size: 64 (balance between memory and gradient variance)
   - epochs: 50 (may need adjustment based on convergence)
   - learning_rate: 0.001 (Adam default, start conservative)
   - weight_decay: 1e-5 (L2 regularization for overfitting prevention)
   
4. LEARNING RATE SCHEDULER
   - scheduler_step: 15 epochs
   - scheduler_gamma: 0.5 (reduce LR by 50% every 15 epochs)
   - Helps convergence in later epochs and prevents divergence

5. DATA PATHS
   - Points to generated dataset and charset

Why These Hyperparameters?
- batch_size=64: Good compromise (smaller batches give noisier gradients, larger batches need more GPU memory)
- lr=0.001: Standard Adam learning rate
- epochs=50: Typical for this dataset size with early stopping
- scheduler_step=15: Halve the learning rate every 15 epochs
- weight_decay=1e-5: Gentle regularization (not too strong)

Customization Tips:
- Increase batch_size if GPU memory allows (improves gradient stability)
- Decrease learning_rate if training is unstable
- Increase epochs if validation loss still decreasing at epoch 50
"""

config_content = f"""# CRNN OCR Configuration for Devanagari
# ======================================
# This configuration file specifies all training hyperparameters

# MODEL ARCHITECTURE
# ------------------
# Number of output classes = unique_characters + 1 (for CTC blank token)
# The +1 is critical for CTC loss computation
num_classes: {len(charset) + 1}

# Number of input channels (1=grayscale, 3=RGB)
# Grayscale is standard for OCR as it's more robust to lighting variations
num_channels: 1

# LSTM hidden layer dimensions
# Typical values: 128, 256, 512
# Larger = more model capacity but slower training and more parameters
hidden_size: 256

# IMAGE DIMENSIONS
# ----------------
# Input image height (in pixels)
# 32 is standard for OCR tasks
img_height: 32

# Input image width (in pixels)
# Should accommodate maximum word length expected
# With this CNN, 256 pixels become 63 time steps, so CTC can emit at most 63 labels per image
img_width: 256

# TRAINING PARAMETERS
# -------------------
# Batch size for training
# Larger batch = more stable gradients, more GPU memory
# Typical range: 32-128 depending on GPU
batch_size: 64

# Number of training epochs
# One epoch = one complete pass through dataset
# Typical range: 30-100 (with early stopping)
epochs: 50

# Initial learning rate for Adam optimizer
# Typical range: 0.0001 - 0.01
# 0.001 is a good starting point
learning_rate: 0.001

# Weight decay (L2 regularization)
# Prevents overfitting by penalizing large weights
# Typical range: 1e-4 to 1e-6
# Written with a decimal point so PyYAML loads it as a float (a bare 1e-5 loads as a string)
weight_decay: 1.0e-5

# LEARNING RATE SCHEDULER
# -----------------------
# Reduce the learning rate every this many epochs (step schedule)
# Helps fine-tune model in later training stages
scheduler_step: 15

# Learning rate decay factor
# new_lr = old_lr * scheduler_gamma
# 0.5 means 50% reduction (LR is halved)
scheduler_gamma: 0.5

# DATA PATHS
# ----------
# Number of training samples
train_samples: 5000

# Images per word (1 means one synthetic image per word)
samples_per_word: 1

# Directory containing TTF fonts for rendering
fonts_dir: "fonts"

# Directory containing generated word images
output_dir: "data/word_images"

# Path to charset file (one line of all characters)
charset_path: "charset.txt"

# AUGMENTATION PARAMETERS
# -----------------------
# Font size range for synthetic image generation (in points)
# Too small: glyphs may be hard to read
# Too large: glyphs may not fit within the image
font_size_range: [40, 56]

# Whether to apply data augmentations during training
apply_augmentation: true
"""

# Save configuration file
with open("config.yaml", "w") as f:
    f.write(config_content)

print("=" * 70)
print("CONFIGURATION FILE CREATED")
print("=" * 70)
print(f"\n✓ config.yaml created successfully")
print(f"\nKey Configuration Summary:")
print(f"  Model:")
print(f"    - Num Classes: {len(charset) + 1} (including CTC blank)")
print(f"    - Hidden Size: 256")
print(f"    - Input Channels: 1 (grayscale)")
print(f"\n  Image Dimensions:")
print(f"    - Height: 32 pixels")
print(f"    - Width: 256 pixels")
print(f"\n  Training:")
print(f"    - Batch Size: 64")
print(f"    - Epochs: 50")
print(f"    - Learning Rate: 0.001")
print(f"    - Training Samples: 5,000")
print(f"\n  Dataset:")
print(f"    - Image Dir: data/word_images/")
print(f"    - Charset Size: {len(charset)} characters")
print(f"    - Charset File: charset.txt")
print(f"\n" + "=" * 70)
print("NEXT STEPS:")
print("=" * 70)
print("1. Ensure you have the following ready:")
print("   - fonts/ directory with Devanagari TTF fonts")
print("   - data/word_images/ (generated in Step 3)")
print("   - charset.txt (generated in Step 4)")
print("\n2. Create scripts/train_word_ocr.py with training loop using:")
print("   - OCRModel defined above")
print("   - config.yaml parameters")
print("   - DataLoader for batch training")
print("   - Adam optimizer")
print("   - CTC loss computation")
print("\n3. Run training:")
print("   $ python scripts/train_word_ocr.py")
print("\n4. Monitor:")
print("   - Training and validation CTC loss")
print("   - Character error rate (CER)")
print("   - Model checkpoints")
print("=" * 70)

## Summary & Next Steps

### Pipeline Completion
- ✅ **Step 1**: Extracted 5,000 unique Nepali words from authentic dataset
- ✅ **Step 2**: Defined synthetic dataset generator with HarfBuzz shaping
- ✅ **Step 3**: Generated 5,000 synthetic training images with augmentations
- ✅ **Step 4**: Created character set from vocabulary (charset.txt)
- ✅ **Step 5**: Designed CRNN+BiLSTM+CTC model architecture
- ✅ **Step 6**: Generated training configuration (config.yaml)

### Files Created
- `data/word_images/` - 5,000 pairs of images (.png) and labels (.txt)
- `charset.txt` - All unique characters, one line
- `config.yaml` - Training configuration and hyperparameters
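
As a rough sketch of how a training script might read these files back, the helpers below load the charset and pair each image with its label. They assume that an image and its label share a file stem (for example `00001.png` and `00001.txt`), which this notebook does not state explicitly. `load_charset` and `list_samples` are hypothetical names.

```python
from pathlib import Path


def load_charset(path="charset.txt"):
    """Read the single-line charset file into a list of characters."""
    text = Path(path).read_text(encoding="utf-8")
    return list(text.rstrip("\n"))  # keep spaces, drop only the newline


def list_samples(image_dir="data/word_images"):
    """Return (image_path, label) pairs for every .png with a matching .txt."""
    samples = []
    for image_path in sorted(Path(image_dir).glob("*.png")):
        label_path = image_path.with_suffix(".txt")  # assumed naming scheme
        if label_path.exists():  # skip images without a label file
            label = label_path.read_text(encoding="utf-8").strip()
            samples.append((image_path, label))
    return samples
```

Reading with an explicit `encoding="utf-8"` matters here because the labels are Devanagari text and the platform default encoding may differ.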

### Ready for Training
The pipeline is now ready for the training script (`scripts/train_word_ocr.py`) which should:
1. Load images from `data/word_images/`
2. Create DataLoader with batch processing
3. Initialize OCRModel with num_classes
4. Train using CTC loss
5. Save model checkpoints
6. Evaluate on validation set
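
For step 4, the labels must be turned into integer indices before they reach `compute_ctc_loss`. The sketch below shows one way to do this under two assumptions: index 0 is reserved for the CTC blank (matching `blank=0` in `OCRModel`), and characters are numbered from 1 in charset order. `build_char_to_index` and `encode_batch` are hypothetical names, and the four-character charset is only for illustration.

```python
def build_char_to_index(charset):
    """Map each character to an index starting at 1; 0 stays the CTC blank."""
    return {ch: i + 1 for i, ch in enumerate(charset)}


def encode_batch(words, char_to_index, pad_value=0):
    """Encode words as padded index lists plus their true lengths."""
    encoded = [[char_to_index[ch] for ch in word] for word in words]
    lengths = [len(seq) for seq in encoded]
    max_len = max(lengths)
    padded = [seq + [pad_value] * (max_len - len(seq)) for seq in encoded]
    return padded, lengths


char_to_index = build_char_to_index(["क", "म", "ल", "ा"])
targets, target_lengths = encode_batch(["कमल", "माला"], char_to_index)
print(targets)         # [[1, 2, 3, 0], [2, 4, 3, 4]]
print(target_lengths)  # [3, 4]
```

The padded lists follow the `(batch, max_target_len)` layout described in the `compute_ctc_loss` docstring. Converted to integer tensors, they would be passed together with `pred_lengths`, which holds the number of time steps the model outputs for each image. Padding positions are ignored because `target_lengths` tells the loss where each label ends.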

### Key Metrics to Track
- **CTC Loss**: Should decrease over epochs
- **Character Error Rate (CER)**: Edit distance between prediction and ground truth, divided by the ground-truth length
- **Word Accuracy**: Percentage of fully correct predictions
- **Validation Loss**: Should fall along with training loss; a sustained rise while training loss keeps falling indicates overfitting
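
The two accuracy metrics above can be sketched in a few lines of plain Python. This version counts Unicode code points as characters, so a Devanagari vowel sign or virama counts as its own character; whether that matches how the charset was built is an assumption. The function names are hypothetical.

```python
def edit_distance(a, b):
    """Levenshtein distance between two sequences."""
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, start=1):
        curr = [i]
        for j, y in enumerate(b, start=1):
            cost = 0 if x == y else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]


def cer(predictions, references):
    """Total edit distance divided by total reference length."""
    total_edits = sum(edit_distance(p, r) for p, r in zip(predictions, references))
    total_chars = sum(len(r) for r in references)
    return total_edits / total_chars


def word_accuracy(predictions, references):
    """Fraction of predictions that match their reference exactly."""
    correct = sum(p == r for p, r in zip(predictions, references))
    return correct / len(references)


predictions = ["कमल", "माल"]
references = ["कमल", "माला"]
print(cer(predictions, references))            # 1 edit over 7 characters
print(word_accuracy(predictions, references))  # 0.5
```

Summing edits and lengths over the whole set, rather than averaging per-word ratios, keeps short words from dominating the score.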

### Future Enhancements
- Increase dataset size (10,000+ images)
- Try different architectures (ResNet backbone, Transformer)
- Add more fonts and augmentations
- Fine-tune on real OCR data
- Implement inference pipeline
- Convert to ONNX for deployment
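
As a starting point for the inference pipeline, the sketch below shows greedy (best-path) CTC decoding: take the most likely class at each time step, collapse consecutive repeats, then drop blanks. It assumes index 0 is the blank, as in `OCRModel`, and uses a hypothetical three-character mapping. Beam search is a common alternative and is not shown.

```python
def ctc_greedy_decode(best_path, index_to_char, blank=0):
    """Collapse repeated indices, then drop blanks."""
    chars = []
    previous = blank
    for index in best_path:
        # Emit only when the index changes and is not the blank
        if index != previous and index != blank:
            chars.append(index_to_char[index])
        previous = index
    return "".join(chars)


index_to_char = {1: "क", 2: "म", 3: "ल"}
# Per-time-step argmax indices for one image,
# e.g. one column of logits.argmax(dim=2)
best_path = [0, 1, 1, 0, 2, 2, 2, 0, 0, 3, 3]
print(ctc_greedy_decode(best_path, index_to_char))  # कमल
```

A blank between two identical indices keeps both characters, which is how CTC represents doubled letters.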