# Handwritten Character Recognition - Inference

This notebook demonstrates how to use trained models for handwritten character recognition. It covers:
- Loading pre-trained model weights.
- Preparing single images for inference.
- Performing predictions on single images.
- A detailed function for extracting and visualizing characters from an image of text, along with their predictions.

## 1. Imports and Setup

In [None]:
# General utilities
import os
import random
import copy # For load_model
import time # For load_model (though not strictly used in inference part)

# Image processing and display
import matplotlib.pyplot as plt
from PIL import Image
import cv2 # OpenCV for image operations
import numpy as np

# PyTorch essentials
import torch
import torch.nn as nn
import torch.optim as optim # For load_model compatibility
import torch.nn.functional as F
import torchvision.datasets as datasets # For get_class_labels_from_dir
import torchvision.transforms as transforms
import torchvision.models as models
from pathlib import Path 
import matplotlib.patches as patches # For detailed visualization

## 2. Device Configuration

In [None]:
# Device configuration: pick the fastest available backend, in order of
# preference Apple-silicon MPS, then CUDA, falling back to the CPU.
def _pick_inference_device():
    """Return the preferred torch.device (MPS, then CUDA, otherwise CPU)."""
    mps_backend = torch.backends.mps
    if mps_backend.is_available() and mps_backend.is_built():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


device = _pick_inference_device()
print(f"Using device: {device}")

## 3. Model Architectures

The model class definitions are needed to load the saved weights, and they must match the definitions used during training exactly. Layer attribute names become `state_dict` keys, so renaming a layer breaks checkpoint loading. The following models expect 3-channel 64×64 input images, as prepared by the data pipeline in `training.ipynb`.

### 3.1. Custom CNNs (`LetterCNN64`, `ImprovedLetterCNN`)

In [None]:
class LetterCNN64(nn.Module):
    """Three-block convolutional classifier for 3-channel character images.

    Each block (3x3 conv, ReLU, 2x2 max-pool) halves the spatial size, so a
    64x64 input reaches the fully connected head as a 128x8x8 feature map.
    Attribute names are kept stable because they are the ``state_dict`` keys
    that saved checkpoints are loaded against.

    Args:
        num_classes: Number of output logits (one per character class).
    """

    def __init__(self, num_classes):
        super(LetterCNN64, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(128 * 8 * 8, 512)
        self.relu4 = nn.ReLU()
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes).

        An unbatched (C, H, W) tensor is treated as a batch of one.

        Raises:
            RuntimeError: if the input does not pool down to an 8x8 map (as a
                64x64 input does), because the flattened size then no longer
                matches ``fc1``.
        """
        if x.dim() == 3:
            # nn.Conv2d accepts unbatched input; keep returning (1, num_classes) for it.
            x = x.unsqueeze(0)
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.pool3(self.relu3(self.conv3(x)))
        # Flatten per sample: a wrongly sized input now fails in fc1 instead of
        # being silently reshaped into extra "batch" rows by view(-1, ...).
        x = torch.flatten(x, 1)
        x = self.relu4(self.fc1(x))
        x = self.fc2(x)
        return x

class ImprovedLetterCNN(nn.Module):
    """Four-block conv classifier with batch norm and a dropout-regularized head.

    Each block (3x3 conv, BatchNorm2d, ReLU, 2x2 max-pool) halves the spatial
    size, so a 64x64 input reaches the head as a 256x4x4 feature map. The model
    contains Dropout and BatchNorm layers, so call ``eval()`` before inference.
    Attribute names are kept stable because they are the ``state_dict`` keys
    that saved checkpoints are loaded against.

    Args:
        num_classes: Number of output logits (one per character class).
    """

    def __init__(self, num_classes):
        super(ImprovedLetterCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.relu4 = nn.ReLU()
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(256 * 4 * 4, 1024)
        self.bn_fc1 = nn.BatchNorm1d(1024)
        self.relu_fc1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(1024, 512)
        self.bn_fc2 = nn.BatchNorm1d(512)
        self.relu_fc2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes).

        Raises:
            RuntimeError: if the input does not pool down to a 4x4 map (as a
                64x64 input does), because the flattened size then no longer
                matches ``fc1``.
        """
        x = self.pool1(self.relu1(self.bn1(self.conv1(x))))
        x = self.pool2(self.relu2(self.bn2(self.conv2(x))))
        x = self.pool3(self.relu3(self.bn3(self.conv3(x))))
        x = self.pool4(self.relu4(self.bn4(self.conv4(x))))
        # Flatten per sample: a wrongly sized input now fails in fc1 instead of
        # being silently reshaped into extra "batch" rows by view(-1, ...).
        x = torch.flatten(x, 1)
        x = self.dropout1(self.relu_fc1(self.bn_fc1(self.fc1(x))))
        x = self.dropout2(self.relu_fc2(self.bn_fc2(self.fc2(x))))
        x = self.fc3(x)
        return x

### 3.2. VGG19 Transfer Learning Model (`VGG19HandwritingModel`)

In [None]:
class VGG19HandwritingModel(nn.Module):
    """Transfer-learning classifier: the VGG19-BN convolutional trunk plus a new head.

    Only ``vgg19_bn(...).features`` is kept; the backbone's own classifier is
    discarded. The head's input size ``512 * 2 * 2`` assumes the trunk reduces
    each input to a 512x2x2 map, as for the 64x64 inputs used elsewhere in this
    notebook (NOTE(review): confirm if using another input size).

    Args:
        num_classes: Number of output logits (one per character class).
        device: Device the submodules are moved to; also stored as ``self.device``.
        pretrained: If True, load ImageNet weights into the trunk and freeze it.
            Pass False when the weights will come from a checkpoint instead.
    """

    def __init__(self, num_classes, device, pretrained=True):
        super(VGG19HandwritingModel, self).__init__()
        self.device = device
        backbone = models.vgg19_bn(weights=models.VGG19_BN_Weights.IMAGENET1K_V1 if pretrained else None)
        # Keep only the convolutional trunk and move just that to the device;
        # moving the whole backbone would also copy its unused ImageNet
        # classifier onto the accelerator right before it is thrown away.
        self.features = backbone.features.to(device)
        del backbone
        # Note: This model definition assumes 3-channel input as prepared by the pipeline.
        # If using a model trained with 1-channel input VGG, the first conv layer would need modification here too.
        if pretrained:
            # Freeze the pretrained trunk so only the new head is trained.
            for param in self.features.parameters():
                param.requires_grad = False
        num_features_output = 512 * 2 * 2
        self.classifier = nn.Sequential(
            nn.Linear(num_features_output, 4096),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(4096, 2048),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(2048, num_classes)
        ).to(device)
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialise the new head: Linear weights ~ N(0, 0.01) and zero biases.

        The BatchNorm1d branch is kept for heads that include such layers; the
        current classifier has none, so it is a no-op today.
        """
        for m in self.classifier.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes)."""
        x = self.features(x)
        x = torch.flatten(x, 1)  # one feature vector per sample
        x = self.classifier(x)
        return x

## 4. Inference Utilities

In [None]:
def load_model_for_inference(model, checkpoint_path, map_location=None):
    """Load checkpoint weights into ``model`` and switch it to evaluation mode.

    Accepts either a training checkpoint dict with a ``'model_state_dict'``
    entry or a bare state_dict. Optimizer/scheduler entries are ignored.

    Args:
        model: An ``nn.Module`` whose architecture matches the checkpoint.
        checkpoint_path: Path to the saved checkpoint file.
        map_location: Passed to ``torch.load``; defaults to the notebook-wide
            ``device`` chosen in the Device Configuration cell.

    Returns:
        The same ``model`` instance with weights loaded and ``eval()`` applied,
        or ``None`` (after printing an error) if the file is missing or the
        weights cannot be loaded.
    """
    if not os.path.exists(checkpoint_path):
        print(f"ERROR: Checkpoint path {checkpoint_path} does not exist. Returning None.")
        return None
    if map_location is None:
        map_location = device
    try:
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
        is_training_checkpoint = isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint
        # Fallback for checkpoints that are just the state_dict itself.
        state_dict = checkpoint['model_state_dict'] if is_training_checkpoint else checkpoint
        model.load_state_dict(state_dict)
    except Exception as e:
        print(f"ERROR loading model from {checkpoint_path}: {e}")
        return None

    model.eval()  # IMPORTANT: Set model to evaluation mode
    print(f"Model loaded successfully from {checkpoint_path} and set to evaluation mode.")
    # Informational metadata only; kept outside the try so a formatting problem
    # here can never discard a model whose weights loaded correctly.
    if isinstance(checkpoint, dict):
        if 'epoch' in checkpoint:
            print(f"  Checkpoint saved at epoch: {checkpoint['epoch']}")
        if 'accuracy' in checkpoint:
            accuracy = checkpoint['accuracy']
            try:
                accuracy_text = f"{float(accuracy):.4f}"
            except (TypeError, ValueError):
                accuracy_text = str(accuracy)
            print(f"  Checkpoint validation accuracy (if saved): {accuracy_text}")
    return model

def get_class_labels_from_dir(data_root_dir):
    """Return the class names an ImageFolder dataset derives from ``data_root_dir``.

    Each subdirectory is treated as one class. Presumably training used
    ImageFolder on the same directory, so the label order lines up with the
    model's output units (TODO confirm against training.ipynb).
    Returns an empty list, after printing a message, if the path is not a
    directory or the dataset cannot be built.
    """
    # isdir() is False for missing paths as well, so it covers both failure modes.
    if not os.path.isdir(data_root_dir):
        print(f"Error: Data root directory '{data_root_dir}' not found or not a directory.")
        return []
    try:
        # The dataset object is built only for its ``classes`` attribute.
        class_names = datasets.ImageFolder(root=data_root_dir).classes
    except Exception as err:
        print(f"Error getting class labels from '{data_root_dir}': {err}")
        return []
    return class_names

def prepare_image_for_inference(image_path, image_size=(64, 64)):
    """Load an image file and turn it into a normalized 3-channel batch of one.

    The image is converted to grayscale, resized to ``image_size``, replicated
    to three channels, normalized with the ImageNet mean/std used by the
    3-channel models here, and moved to the notebook-wide ``device``.

    Args:
        image_path: Path to the image file.
        image_size: Target (height, width) for ``transforms.Resize``; (64, 64)
            matches the models defined in this notebook.

    Returns:
        A tensor of shape (1, 3, H, W) on ``device``, or ``None`` (after
        printing an error) if the file is missing or cannot be processed.
    """
    if not os.path.exists(image_path):
        print(f"Error: Image path '{image_path}' not found.")
        return None
    try:
        # Open in a context manager so the file handle is released promptly;
        # convert() returns a fully loaded copy that outlives the handle.
        with Image.open(image_path) as source_img:
            img = source_img.convert('L')  # Convert to grayscale
        # Normalization parameters should be consistent with training (ImageNet stats for 3-channel models here)
        normalize_transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        inference_transform_pipeline = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0) == 1 else x),  # Ensure 3 channels
            normalize_transform
        ])
        img_tensor = inference_transform_pipeline(img)
        return img_tensor.unsqueeze(0).to(device)  # Add batch dimension and send to device
    except Exception as e:
        print(f"Error processing image {image_path}: {e}")
        return None

## 5. Single Image Inference Example

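This section loads a trained model and classifies a single image. Before running the cell, update the three paths in its configuration block. If any path is missing, the cell prints a message and skips inference.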

In [None]:
print("--- Single Image Inference Example ---")

# --- USER CONFIGURATION REQUIRED --- #
# 1. Path to the saved model checkpoint (e.g., from './model_checkpoints/cnn/best_model.pth')
user_model_checkpoint_path = "./model_checkpoints/cnn/best_model.pth"  # <<< UPDATE THIS PATH

# 2. Path to the root directory of the dataset used for training (to get class labels)
#    (e.g., './datasets/handwritten-english/augmented_images/augmented_images1')
user_data_root_for_labels = "./datasets/handwritten-english/augmented_images/augmented_images1"  # <<< UPDATE THIS PATH

# 3. Path to the image you want to classify
user_image_for_prediction = ""  # <<< UPDATE THIS PATH (e.g., "./path/to/your/image.png")
# --- END OF USER CONFIGURATION --- #

loaded_inference_model = None
class_labels_for_inference = []

if not os.path.exists(user_model_checkpoint_path):
    print(f"ERROR: Model checkpoint '{user_model_checkpoint_path}' not found. Please update the path.")
elif not os.path.exists(user_data_root_for_labels):
    print(f"ERROR: Data root for labels '{user_data_root_for_labels}' not found. Please update the path.")
else:
    class_labels_for_inference = get_class_labels_from_dir(user_data_root_for_labels)
    if not class_labels_for_inference:
        print("ERROR: Could not retrieve class labels.")
    else:
        num_classes = len(class_labels_for_inference)
        print(f"Successfully loaded {num_classes} class labels. Example: {class_labels_for_inference[:5]}")

        # Initialize model architecture - this MUST match the checkpoint's model type
        # Assuming ImprovedLetterCNN for this example. Change if your checkpoint is for a different model.
        # E.g., for VGG19: model_architecture = VGG19HandwritingModel(num_classes, device, pretrained=False).to(device)
        model_architecture = ImprovedLetterCNN(num_classes).to(device)

        loaded_inference_model = load_model_for_inference(model_architecture, user_model_checkpoint_path)

# Compare against None explicitly: container modules (e.g. nn.Sequential) define
# __len__, so their truthiness would depend on their length rather than on loading.
if loaded_inference_model is not None and class_labels_for_inference:
    if not user_image_for_prediction or not os.path.exists(user_image_for_prediction):
        print(f"INFO: 'user_image_for_prediction' ('{user_image_for_prediction}') is not set or does not exist.")
        print("Please provide a valid image path to perform inference.")
    else:
        input_image_tensor = prepare_image_for_inference(user_image_for_prediction)
        if input_image_tensor is not None:
            with torch.no_grad():
                outputs = loaded_inference_model(input_image_tensor)
                probabilities = torch.nn.functional.softmax(outputs, dim=1)
                confidence, predicted_idx = torch.max(probabilities, 1)
                predicted_label = class_labels_for_inference[predicted_idx.item()]
                confidence_percent = confidence.item() * 100

                print(f"\n--- Prediction Results for: {user_image_for_prediction} ---")
                print(f"Predicted Class: {predicted_label}")
                print(f"Confidence: {confidence_percent:.2f}%")

                try:
                    # Context manager releases the file handle once the figure is shown.
                    with Image.open(user_image_for_prediction) as img_to_display:
                        plt.figure(figsize=(3, 3))
                        # cmap only affects single-channel images; without it matplotlib
                        # would render grayscale inputs with its default false-colour map.
                        plt.imshow(img_to_display, cmap='gray')
                        plt.title(f"Predicted: {predicted_label} ({confidence_percent:.2f}%)")
                        plt.axis('off')
                        plt.show()
                except Exception as e_disp:
                    print(f"Error displaying image: {e_disp}")
        else:
            print(f"Could not prepare image '{user_image_for_prediction}' for inference.")
else:
    print("Cannot proceed with single image inference due to model or class label loading issues.")

## 6. Detailed Multi-Character Image Analysis

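The `extract_letters_detailed_visualization` function segments an image of handwritten text into individual characters and classifies each segment. It also displays every intermediate processing step and saves each one as a numbered PNG in `output_dir`, which shows exactly where segmentation or classification goes wrong. Each external contour becomes one segment. As a result, characters made of disconnected strokes (such as 'i' or 'j') may be split or lose small parts, and touching characters may be merged into one segment.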

In [None]:
def _save_and_show_figure(output_dir, filename):
    """Save the current figure as ``output_dir/filename``, display it, then close it."""
    plt.savefig(f"{output_dir}/{filename}")
    plt.show()
    # Close explicitly so repeated calls outside an inline notebook backend do
    # not accumulate open figures.
    plt.close()


def _show_full_image(img, title, output_dir, filename, cmap=None):
    """Render one whole-image processing step on a 10x5 figure and save it."""
    plt.figure(figsize=(10, 5))
    plt.imshow(img, cmap=cmap)
    plt.title(title)
    _save_and_show_figure(output_dir, filename)


def _grid_shape(num_items):
    """Return (rows, cols) of a grid with at most 5 columns holding ``num_items`` panels."""
    return max(1, (num_items + 4) // 5), min(5, num_items)


def _roi_to_model_image(letter_roi, spacing):
    """Fit a binary letter ROI onto a white 64x64 grayscale PIL image.

    The ROI is scaled, keeping its aspect ratio, so its longer side spans
    ``64 - spacing`` pixels. It is centred on a white square of that size, and
    the square is pasted ``spacing // 2`` pixels from the canvas's top-left.
    """
    h, w = letter_roi.shape[:2]
    target_s = 64 - spacing
    if w > h:
        new_w, new_h = target_s, int((h / w) * target_s)
    else:
        new_h, new_w = target_s, int((w / h) * target_s)
    # Very elongated ROIs (e.g. a long thin stroke) would otherwise round the
    # short side down to 0, which PIL's resize rejects.
    new_w, new_h = max(1, new_w), max(1, new_h)
    resized = Image.fromarray(letter_roi).convert('L').resize((new_w, new_h), Image.LANCZOS)
    padded = Image.new('L', (target_s, target_s), 255)  # White background
    padded.paste(resized, ((target_s - new_w) // 2, (target_s - new_h) // 2))
    final_img = Image.new('L', (64, 64), 255)
    spacing_offset = spacing // 2
    final_img.paste(padded, (spacing_offset, spacing_offset))
    return final_img


def extract_letters_detailed_visualization(image_path, model, class_list, output_dir="extracted_letters_viz", spacing=0):
    """
    Extract individual letters from a handwritten text image and classify them.

    Pipeline:
      1. grayscale;
      2. Otsu thresholds;
      3. external contours;
      4. area filter;
      5. left-to-right sort;
      6. one 64x64 model input per contour;
      7. softmax prediction.

    Every step is displayed and saved as a numbered PNG in ``output_dir``.

    Args:
        image_path: Image path readable by ``cv2.imread``.
        model: Pre-loaded classifier; it is switched to ``eval()`` here.
        class_list: Labels indexed by the model's output position.
        output_dir: Directory for the saved figures (created with parents if needed).
        spacing: Total white margin, in pixels, reserved around each character
            inside the 64x64 model input; must be smaller than 64.

    Returns:
        A list of ``(char, (x, y, w, h))`` tuples in left-to-right order, or
        ``[]`` if the model/labels are missing or no contour survives filtering.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
        ValueError: if ``spacing`` is 64 or more (no room left for the character).
    """
    if model is None:
        print("Model not provided to extract_letters_detailed_visualization.")
        return []
    if not class_list:
        print("Class list not provided to extract_letters_detailed_visualization.")
        return []
    if spacing >= 64:
        raise ValueError(f"spacing must be less than 64, got {spacing}")

    model.eval()
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")

    print(f"Visualizing steps for {image_path}, outputting to {output_dir}/")
    # Step 1: Original Image (OpenCV loads BGR; matplotlib expects RGB)
    _show_full_image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), "Original Image", output_dir, "1_original.png")

    # Step 2: Grayscale
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _show_full_image(gray, "Grayscale Image", output_dir, "2_grayscale.png", cmap='gray')

    # Step 3a: Binary for Contours (Inverted OTSU). Assuming dark ink on a light
    # background, strokes become the non-zero foreground that findContours traces.
    _, binary_contours = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    _show_full_image(binary_contours, "Binary for Contours (Inverted OTSU)", output_dir, "3a_binary_contours.png", cmap='gray')

    # Step 3b: Binary for Model Input (Non-inverted OTSU), matching the white
    # canvas that the model inputs are padded onto below.
    _, binary_model_prep = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    _show_full_image(binary_model_prep, "Binary for Model Input (Non-inverted OTSU)", output_dir, "3b_binary_model_prep.png", cmap='gray')

    # Step 4: Find & Visualize All Contours
    contours, _ = cv2.findContours(binary_contours, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    contour_img = image.copy()
    cv2.drawContours(contour_img, contours, -1, (0, 255, 0), 2)
    _show_full_image(cv2.cvtColor(contour_img, cv2.COLOR_BGR2RGB), f"All Contours ({len(contours)})", output_dir, "4_all_contours.png")

    # Step 5: Filter & Visualize Filtered Contours
    min_area = 50
    letter_contours = [cnt for cnt in contours if cv2.contourArea(cnt) > min_area]
    print(f"Found {len(contours)} total contours, {len(letter_contours)} after filtering by min area {min_area}")
    filtered_contour_img = image.copy()
    cv2.drawContours(filtered_contour_img, letter_contours, -1, (0, 255, 0), 2)
    _show_full_image(cv2.cvtColor(filtered_contour_img, cv2.COLOR_BGR2RGB), f"Filtered Contours ({len(letter_contours)})", output_dir, "5_filtered_contours.png")

    if not letter_contours:
        print("No letter contours after filtering.")
        return []
    # Reading order: sort by the left edge of each bounding box.
    letter_contours = sorted(letter_contours, key=lambda cnt: cv2.boundingRect(cnt)[0])

    # Step 6: Visualize Extracted ROIs before processing
    rows, cols = _grid_shape(len(letter_contours))
    plt.figure(figsize=(15, 3 * rows))
    plt.suptitle("Extracted ROIs Before Processing", fontsize=16)
    for i_viz, c_viz in enumerate(letter_contours):
        x_viz, y_viz, w_viz, h_viz = cv2.boundingRect(c_viz)
        plt.subplot(rows, cols, i_viz + 1)
        plt.imshow(binary_model_prep[y_viz:y_viz + h_viz, x_viz:x_viz + w_viz], cmap='gray')
        plt.title(f"ROI {i_viz}")
        plt.axis('off')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    _save_and_show_figure(output_dir, "6_extracted_rois.png")

    letters_info_final = []
    model_input_images_viz = []
    # This transform should match prepare_image_for_inference's core logic for 3-channel models
    inference_transform_detailed = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x_tensor: x_tensor.repeat(3, 1, 1) if x_tensor.size(0) == 1 else x_tensor),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    for i, contour in enumerate(letter_contours):
        x, y, w, h = cv2.boundingRect(contour)
        if w < 5 or h < 5:
            continue  # skip boxes narrower or shorter than 5 px
        final_img = _roi_to_model_image(binary_model_prep[y:y + h, x:x + w], spacing)
        model_input_images_viz.append(np.array(final_img))
        letter_tensor = inference_transform_detailed(final_img).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(letter_tensor)
            probabs = torch.nn.functional.softmax(outputs, dim=1)
            conf, pred_idx = torch.max(probabs, 1)
        char = class_list[pred_idx.item()]
        print(f"Segment {i}: Classified as '{char}' with conf {conf.item():.2f}")
        letters_info_final.append((char, (x, y, w, h)))

    # Step 7: Visualize Model Inputs with Borders
    # Red: full 64x64 input; blue: area reserved for the character inside the margin.
    if model_input_images_viz:
        rows, cols = _grid_shape(len(model_input_images_viz))
        plt.figure(figsize=(15, 3 * rows))
        plt.suptitle("Processed Segments for Model", fontsize=16)
        inner_size = 64 - spacing - 1
        inner_offset = spacing // 2
        for i_viz, (img_viz, (char, _)) in enumerate(zip(model_input_images_viz, letters_info_final)):
            ax = plt.subplot(rows, cols, i_viz + 1)
            plt.imshow(img_viz, cmap='gray')
            ax.add_patch(patches.Rectangle((0, 0), 63, 63, lw=1, edgecolor='r', fc='none'))
            ax.add_patch(patches.Rectangle((inner_offset, inner_offset), inner_size, inner_size, lw=1, edgecolor='b', fc='none'))
            plt.title(f"Input {i_viz}: {char}")
            plt.axis('off')
        plt.tight_layout(rect=[0, 0, 1, 0.95])
        _save_and_show_figure(output_dir, "7_model_inputs_bordered.png")

    # Step 8: Final Result Visualization
    result_img = image.copy()
    for char, (x, y, w, h) in letters_info_final:
        cv2.rectangle(result_img, (x, y), (x + w, y + h), (0, 255, 0), 2)
        # Keep the label's baseline inside the image for characters near the top
        # edge; y - 10 alone would clip the text partly or fully off-image.
        cv2.putText(result_img, char, (x, max(y - 10, 20)), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 0), 2)
    _show_full_image(cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB), "Final Classification on Image", output_dir, "8_final_classification.png")
    return letters_info_final

# --- Example Usage for Detailed Multi-Character Extraction ---
print("\n--- Detailed Multi-Character Image Analysis Example ---")
user_multi_char_image_path = ""  # <<< USER: PROVIDE PATH TO AN IMAGE WITH MULTIPLE CHARACTERS

if not (user_multi_char_image_path and os.path.exists(user_multi_char_image_path)):
    print(f"INFO: 'user_multi_char_image_path' ('{user_multi_char_image_path}') is not set or does not exist.")
    print("Skipping detailed multi-character extraction example. Provide a valid image path to run this.")
elif loaded_inference_model is None:  # Model from the single image inference section failed to load
    print("ERROR: Model not loaded from single image inference section. Cannot run detailed extraction.")
    print(f"Ensure '{user_model_checkpoint_path}' is a valid model path and was loaded successfully.")
elif not class_labels_for_inference:
    print("ERROR: Class labels not loaded. Cannot run detailed extraction.")
else:
    print(f"Running detailed extraction for image: {user_multi_char_image_path}")
    # Reuses 'loaded_inference_model' and 'class_labels_for_inference' from the
    # single image example section; make sure that is the model you want here.
    detailed_extracted_info = extract_letters_detailed_visualization(
        user_multi_char_image_path,
        loaded_inference_model,
        class_labels_for_inference,
        output_dir="./detailed_extraction_output",  # You can change the output directory
        spacing=12  # Adjust spacing around characters if needed
    )
    if detailed_extracted_info:
        print("\n--- Detailed Extraction Results ---")
        for extracted_char, bounding_box in detailed_extracted_info:
            print(f"Character: {extracted_char}, Bounding Box: {bounding_box}")
    else:
        print("No characters were extracted or classified in the detailed analysis.")