# Handwritten Character Recognition - Gradio Application

This notebook launches an interactive Gradio application for handwritten character recognition. You can either draw characters/words directly in the interface or upload an image containing handwritten text. The application will then attempt to segment and recognize the characters using a pre-trained model.

**Prerequisites:**
1.  **Trained Models:** Ensure you have trained models saved in the appropriate checkpoint directories (e.g., `./model_checkpoints/cnn/best_model.pth` or `./model_checkpoints/vgg/best_model.pth`). The training notebook (`training.ipynb`) should produce these.
2.  **Dataset for Class Labels:** The path to the root of the training dataset is needed to derive the class labels (character names). Update `DATA_ROOT_FOR_APP_LABELS` (defined in Section 5) if your dataset lives elsewhere.

## 1. Imports and Setup

In [None]:
# General utilities
import os
import random

# Image processing and display
import matplotlib.pyplot as plt
from PIL import Image
import cv2 # OpenCV for image operations
import numpy as np

# PyTorch essentials
import torch
import torch.nn as nn
import torch.optim as optim # May not be needed directly, but models might reference it if part of saved state
import torch.nn.functional as F
import torchvision.datasets as datasets # For get_class_labels
import torchvision.transforms as transforms
import torchvision.models as models
from pathlib import Path 

# Gradio for the web application
!pip install gradio -q
import gradio as gr
print(f"Gradio version: {gr.__version__}")

## 2. Device Configuration

In [None]:
# Device configuration: prefer Apple MPS, then CUDA, and fall back to the CPU.
device = torch.device(
    "mps"
    if torch.backends.mps.is_available() and torch.backends.mps.is_built()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)
print(f"Using device: {device}")

## 3. Model Architectures

The definitions of the models are needed to load the saved weights. Ensure these match the definitions used during training.

### 3.1. Custom CNNs (`LetterCNN64`, `ImprovedLetterCNN`)

In [None]:
class LetterCNN64(nn.Module):
    """Three-stage convolutional classifier for 3-channel 64x64 images.

    Each stage is Conv(3x3, padding=1) -> ReLU -> MaxPool(2), so the spatial
    size is halved three times (64 -> 32 -> 16 -> 8) before two fully
    connected layers produce the class logits.

    The attribute names below become the ``state_dict`` keys of saved
    checkpoints, so they must not be renamed.

    Args:
        num_classes: Number of output logits (one per character class).
    """

    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 128 channels on an 8x8 map after three poolings of a 64x64 input.
        self.fc1 = nn.Linear(128 * 8 * 8, 512)
        self.relu4 = nn.ReLU()
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for a batch of images.

        Raises:
            RuntimeError: If the input's spatial size does not yield the
                128*8*8 features that ``fc1`` expects (i.e. not 64x64).
        """
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.pool3(self.relu3(self.conv3(x)))
        # Flatten per sample. Using view(-1, 128*8*8) here would silently fold
        # surplus spatial data into the batch dimension for wrong-sized inputs
        # and return predictions for a batch size the caller never passed.
        x = x.view(x.size(0), -1)
        x = self.relu4(self.fc1(x))
        x = self.fc2(x)
        return x

class ImprovedLetterCNN(nn.Module):
    """Four-stage CNN with batch normalisation and dropout for 64x64 images.

    Each convolutional stage is Conv(3x3, padding=1) -> BatchNorm -> ReLU ->
    MaxPool(2), halving the spatial size four times (64 -> 32 -> 16 -> 8 -> 4).
    The head is two Linear -> BatchNorm1d -> ReLU -> Dropout(0.5) blocks
    followed by the final classification layer.

    The attribute names below become the ``state_dict`` keys of saved
    checkpoints, so they must not be renamed.

    Args:
        num_classes: Number of output logits (one per character class).
    """

    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.relu4 = nn.ReLU()
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 256 channels on a 4x4 map after four poolings of a 64x64 input.
        self.fc1 = nn.Linear(256 * 4 * 4, 1024)
        self.bn_fc1 = nn.BatchNorm1d(1024)
        self.relu_fc1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(1024, 512)
        self.bn_fc2 = nn.BatchNorm1d(512)
        self.relu_fc2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for a batch of images.

        Raises:
            RuntimeError: If the input's spatial size does not yield the
                256*4*4 features that ``fc1`` expects (i.e. not 64x64).
        """
        x = self.pool1(self.relu1(self.bn1(self.conv1(x))))
        x = self.pool2(self.relu2(self.bn2(self.conv2(x))))
        x = self.pool3(self.relu3(self.bn3(self.conv3(x))))
        x = self.pool4(self.relu4(self.bn4(self.conv4(x))))
        # Flatten per sample. Using view(-1, 256*4*4) here would silently fold
        # surplus spatial data into the batch dimension for wrong-sized inputs
        # and return predictions for a batch size the caller never passed.
        x = x.view(x.size(0), -1)
        x = self.dropout1(self.relu_fc1(self.bn_fc1(self.fc1(x))))
        x = self.dropout2(self.relu_fc2(self.bn_fc2(self.fc2(x))))
        x = self.fc3(x)
        return x

### 3.2. VGG19 Transfer Learning Model (`VGG19HandwritingModel`)

In [None]:
class VGG19HandwritingModel(nn.Module):
    """VGG19-BN convolutional backbone with a custom three-layer classifier head.

    Args:
        num_classes: Number of output logits.
        device: Device the backbone and the head are moved to; also stored
            on the instance as ``self.device``.
        pretrained: When true, the backbone is created with the
            ``IMAGENET1K_V1`` weights and its parameters are frozen.
            When false, the backbone is created without weights and left
            trainable.
    """

    def __init__(self, num_classes, device, pretrained=True):
        super().__init__()
        self.device = device

        # Only the convolutional feature stack of torchvision's VGG19-BN is kept.
        backbone_weights = models.VGG19_BN_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = models.vgg19_bn(weights=backbone_weights).to(device)
        self.features = backbone.features

        if pretrained:
            # Freeze the backbone so only the new head receives gradients.
            for backbone_param in self.features.parameters():
                backbone_param.requires_grad = False

        # Size of the flattened backbone output fed to the head
        # (presumably 512 channels on a 2x2 map for 64x64 inputs - confirm).
        flattened_features = 512 * 2 * 2
        head_layers = [
            nn.Linear(flattened_features, 4096),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(4096, 2048),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(2048, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers).to(device)
        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise the head: N(0, 0.01) weights and zero biases for Linear layers."""
        for layer in self.classifier.modules():
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                nn.init.constant_(layer.bias, 0)
                continue
            if isinstance(layer, nn.BatchNorm1d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Run the backbone, flatten per sample, and return the class logits."""
        feature_maps = self.features(x)
        flat = feature_maps.view(feature_maps.size(0), -1)
        return self.classifier(flat)

## 4. Application Utilities

In [None]:
def load_model_for_inference(model, checkpoint_path):
    """Load checkpoint weights into ``model`` and switch it to evaluation mode.

    The checkpoint may be either a raw ``state_dict`` or a dict that stores
    the ``state_dict`` under the ``'model_state_dict'`` key.

    Args:
        model: An ``nn.Module`` whose architecture matches the checkpoint.
        checkpoint_path: Path to the checkpoint file.

    Returns:
        The same ``model`` instance in eval mode, or ``None`` when the file
        does not exist or loading fails (the reason is printed).
    """
    if not os.path.exists(checkpoint_path):
        print(f"ERROR: Checkpoint path {checkpoint_path} does not exist. Cannot load model.")
        return None # Return None if checkpoint not found
    try:
        # Map the stored tensors onto the device the model already lives on,
        # instead of relying on a module-level `device` global being defined
        # and in sync with the model.
        first_param = next(model.parameters(), None)
        target_device = first_param.device if first_param is not None else torch.device("cpu")

        # NOTE: torch.load unpickles the file - only load checkpoints you trust.
        checkpoint = torch.load(checkpoint_path, map_location=target_device)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint # Fallback for older checkpoints (raw state_dict)
        model.load_state_dict(state_dict)
        model.eval()
        print(f"Model loaded successfully from {checkpoint_path} and set to evaluation mode.")
        return model
    except Exception as e:
        # Deliberate best-effort boundary: report the problem and let the
        # caller decide what to do with a missing model.
        print(f"ERROR loading model from {checkpoint_path}: {e}")
        return None

def get_class_labels_from_dir(data_root_for_labels):
    """Return the class labels encoded as sub-folder names of a dataset root.

    The labels are the names of the directories directly under
    ``data_root_for_labels``, sorted as plain strings. This is the same
    derivation torchvision's ``ImageFolder`` uses for its ``classes`` list,
    without indexing every image file in the dataset.

    Args:
        data_root_for_labels: Root directory laid out as one folder per class.

    Returns:
        Sorted list of class names, or an empty list when the root is missing,
        unreadable, or contains no sub-folders (a message is printed).
    """
    if not os.path.exists(data_root_for_labels):
        print(f"Error: Data root for labels '{data_root_for_labels}' not found.")
        return []
    try:
        with os.scandir(data_root_for_labels) as entries:
            class_names = sorted(entry.name for entry in entries if entry.is_dir())
    except OSError as e:
        print(f"Error getting class labels from '{data_root_for_labels}': {e}")
        return []
    if not class_names:
        print(f"Error getting class labels from '{data_root_for_labels}': no class folders found.")
    return class_names

## 5. Gradio Application Setup

In [None]:
# --- Configuration for Gradio App ---
# Checkpoint locations (relative to the working directory) and the dataset root
# whose class-folder names are used as the prediction labels.
MODEL_PATH_CNN_APP = os.path.join('model_checkpoints', 'cnn', 'best_model.pth')
MODEL_PATH_VGG_APP = os.path.join('model_checkpoints', 'vgg', 'best_model.pth')
DATA_ROOT_FOR_APP_LABELS = "./datasets/handwritten-english/augmented_images/augmented_images1" # <<< USER: Update if your dataset path is different

app_class_labels = get_class_labels_from_dir(DATA_ROOT_FOR_APP_LABELS)
app_num_classes = len(app_class_labels)

if app_num_classes == 0:
    print("WARNING: Could not load class labels. Predictions might be incorrect or app might fail.")
    print("Ensure DATA_ROOT_FOR_APP_LABELS points to a valid ImageFolder dataset structure.")
    # Not ideal: substitute generic labels "0".."61", assuming a 62-class model.
    app_num_classes = 62
    app_class_labels = [str(index) for index in range(app_num_classes)]

selected_model_name = "ImprovedLetterCNN" # Default choice
app_model = None

# The CNN checkpoint takes priority; the VGG checkpoint is only considered
# when no CNN checkpoint file exists.
if os.path.exists(MODEL_PATH_CNN_APP):
    print(f"Attempting to load {selected_model_name} from {MODEL_PATH_CNN_APP}")
    temp_model_cnn = ImprovedLetterCNN(app_num_classes).to(device)
    app_model = load_model_for_inference(temp_model_cnn, MODEL_PATH_CNN_APP)
elif not os.path.exists(MODEL_PATH_VGG_APP):
    print(f"ERROR: No model checkpoint found at {MODEL_PATH_CNN_APP} or {MODEL_PATH_VGG_APP}.")
    print("Please ensure a trained model is available at one of these paths for the app to function.")
else:
    selected_model_name = "VGG19HandwritingModel"
    print(f"CNN model not found. Attempting to load {selected_model_name} from {MODEL_PATH_VGG_APP}")
    # pretrained=False: the weights come from the checkpoint, not from ImageNet.
    temp_model_vgg = VGG19HandwritingModel(app_num_classes, device, pretrained=False).to(device)
    app_model = load_model_for_inference(temp_model_vgg, MODEL_PATH_VGG_APP)

# Report the outcome of the loading attempt.
print(
    f"Successfully loaded {selected_model_name} for the Gradio app."
    if app_model
    else "WARNING: Failed to load any model. The Gradio app might not work correctly."
)

In [None]:
def process_image_for_gradio(image_input_np, model_to_use, class_list, spacing=24):
    """
    Segment an RGB image into characters, classify each one, and collect debug views.

    Pipeline: grayscale -> inversion when the image is dark on average ->
    Otsu thresholding -> external contours -> area filter -> left-to-right
    ordering. Each remaining bounding box is cropped from the binarised image,
    resized with its aspect ratio kept, centred on a white 64x64 canvas and
    passed to the model.

    Args:
        image_input_np: NumPy array (RGB) from Gradio input, or None.
        model_to_use: The pre-loaded PyTorch model, or None.
        class_list: List of class names, indexed by the model's predicted index.
        spacing: Total white margin (in pixels) kept around each character
            inside the 64x64 model input; the glyph is fitted into a
            ``64 - spacing`` square.

    Returns:
        Tuple ``(visualization_steps, detected_text)``. ``visualization_steps``
        is a list of ``(label, RGB image)`` tuples describing the processing
        stages; ``detected_text`` is the concatenation of the predicted class
        names (or a status message when nothing could be processed).
    """
    if image_input_np is None:
        empty_img = np.ones((100,100,3), dtype=np.uint8) * 255
        cv2.putText(empty_img, "No Image", (10,50), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,0,0), 1)
        return [("No Image Provided", empty_img)], "No image provided"

    if model_to_use is None:
        empty_img = np.ones((100,100,3), dtype=np.uint8) * 255
        cv2.putText(empty_img, "Model Not Loaded", (10,50), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,0,0), 1)
        return [("Model Not Loaded", empty_img)], "ERROR: Model not loaded"

    # Convert RGB NumPy array to BGR for OpenCV
    image_bgr = cv2.cvtColor(image_input_np, cv2.COLOR_RGB2BGR)
    visualization_steps = [("Original Input", image_input_np.copy())]

    gray = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)
    visualization_steps.append(("Grayscale", cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)))

    # Invert colors if background is dark (common for sketchpad if not inverted there)
    # Simple heuristic: if mean < 128, assume dark background / light text
    if np.mean(gray) < 128:
        gray = 255 - gray
        visualization_steps.append(("Inverted Gray (if needed)", cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)))

    # Two binarisations of the same image: the inverted one is used to find
    # contours, the non-inverted one (black text on white) feeds the model.
    _, binary_for_contours = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    visualization_steps.append(("Binary for Contours", cv2.cvtColor(binary_for_contours, cv2.COLOR_GRAY2RGB)))

    _, binary_for_model = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) # Black text on white
    visualization_steps.append(("Binary for Model Prep", cv2.cvtColor(binary_for_model, cv2.COLOR_GRAY2RGB)))

    contours, _ = cv2.findContours(binary_for_contours, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    contour_img_viz = image_input_np.copy()
    cv2.drawContours(contour_img_viz, contours, -1, (0,255,0), 2)
    visualization_steps.append((f"All Contours ({len(contours)})", contour_img_viz))

    if not contours:
        return visualization_steps, "No contours found."

    # Drop specks, then order the remaining contours left to right so the
    # predicted characters are concatenated in reading order (single line).
    min_area = 30
    letter_contours = [cnt for cnt in contours if cv2.contourArea(cnt) > min_area]
    if not letter_contours:
        return visualization_steps, f"No contours > area {min_area}."
    letter_contours = sorted(letter_contours, key=lambda cnt: cv2.boundingRect(cnt)[0])

    filtered_contour_img_viz = image_input_np.copy()
    cv2.drawContours(filtered_contour_img_viz, letter_contours, -1, (255,0,0), 2)
    visualization_steps.append((f"Filtered Contours ({len(letter_contours)})", filtered_contour_img_viz))

    detected_chars = []
    model_inputs_visual_list = []
    final_preds_on_image = image_input_np.copy()
    model_to_use.eval() # Ensure model is in eval mode

    # Define the transformation for model input (consistent with training)
    gradio_input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0)==1 else x),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    target_s = 64 - spacing
    spacing_offset = spacing // 2

    for contour in letter_contours:
        x, y, w, h = cv2.boundingRect(contour)
        if w < 5 or h < 5:
            continue

        letter_roi = binary_for_model[y:y+h, x:x+w] # Use the non-inverted binary image for the ROI
        letter_pil = Image.fromarray(letter_roi).convert('L')

        # Fit the longer side to target_s and scale the other proportionally.
        # Clamp to at least 1 pixel: for very elongated boxes the truncated
        # short side would be 0, which makes PIL's resize raise.
        if w > h:
            new_w, new_h = target_s, max(1, int((h / w) * target_s))
        else:
            new_h, new_w = target_s, max(1, int((w / h) * target_s))

        resized = letter_pil.resize((new_w, new_h), Image.LANCZOS)
        padded = Image.new('L', (target_s, target_s), 255) # White background for padding
        px, py = (target_s - new_w)//2, (target_s - new_h)//2
        padded.paste(resized, (px, py))

        final_img_for_model_input = Image.new('L', (64,64), 255)
        final_img_for_model_input.paste(padded, (spacing_offset, spacing_offset))
        model_inputs_visual_list.append(np.array(final_img_for_model_input))

        img_tensor = gradio_input_transform(final_img_for_model_input).unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model_to_use(img_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
            confidence, predicted_idx = torch.max(probabilities, 1)
            char_name = class_list[predicted_idx.item()]

        detected_chars.append(char_name)
        cv2.rectangle(final_preds_on_image, (x,y), (x+w,y+h), (0,0,255), 2)
        # Keep the label baseline inside the image for boxes touching the top edge.
        label_y = max(y - 10, 12)
        cv2.putText(final_preds_on_image, f"{char_name} ({confidence.item():.2f})", (x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,0,255), 1)

    if model_inputs_visual_list:
        model_inputs_grid = create_image_grid_for_gradio([cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) for img in model_inputs_visual_list], "Processed Segments for Model")
        if model_inputs_grid is not None:
            visualization_steps.append(("Model Inputs", model_inputs_grid))

    visualization_steps.append(("Predictions on Image", final_preds_on_image))
    return visualization_steps, "".join(detected_chars)

def create_image_grid_for_gradio(images_list, title):
    """Render a list of images into one RGB grid image for the Gradio gallery.

    Args:
        images_list: Images to show, five per row. Each item is either an
            image array or an ``(image_title, image)`` tuple.
        title: Figure title; when empty, no title is drawn and no space is
            reserved for it.

    Returns:
        ``uint8`` array of shape ``(height, width, 3)`` holding the rendered
        grid, or ``None`` when ``images_list`` is empty.
    """
    if not images_list:
        return None

    # Use Matplotlib's object-oriented API with an explicit Agg canvas. It
    # needs no GUI backend or DISPLAY and does not touch pyplot's global
    # figure registry, which is not safe to use from web-server worker threads.
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    from matplotlib.figure import Figure

    cols = 5
    rows = (len(images_list) + cols - 1) // cols
    fig = Figure(figsize=(cols * 2.5, rows * 2.5))
    canvas = FigureCanvasAgg(fig)
    if title:
        fig.suptitle(title)

    for position, img_data in enumerate(images_list, start=1):
        ax = fig.add_subplot(rows, cols, position)
        img_title, img = img_data if isinstance(img_data, tuple) else (None, img_data)
        if img_title:
            ax.set_title(img_title)
        cmap = 'gray' if len(img.shape) == 2 else None
        ax.imshow(img, cmap=cmap)
        ax.axis('off')
    # Leave the top 5% of the figure free for the suptitle when there is one.
    fig.tight_layout(rect=[0, 0, 1, 0.95 if title else 1])

    canvas.draw()
    # buffer_rgba() replaces the deprecated/removed tostring_rgb(); drop the
    # alpha channel and copy so the result does not alias the canvas buffer.
    rgba = np.asarray(canvas.buffer_rgba())
    return np.ascontiguousarray(rgba[..., :3])

def launch_gradio_app(model_for_app, class_names_for_app):
    """Build the Gradio Blocks interface (draw tab and upload tab) and launch it.

    Args:
        model_for_app: Loaded PyTorch model used for recognition; when None a
            warning is shown in the UI.
        class_names_for_app: List of class names passed to the recognition
            pipeline.

    The header text also reads the module-level ``selected_model_name`` and
    ``app_num_classes``.
    """
    def to_gallery_items(steps):
        """Reorder ``(label, image)`` pairs into the ``(image, caption)`` pairs gr.Gallery expects."""
        items = []
        for step in steps:
            is_label_first = (
                isinstance(step, (tuple, list))
                and len(step) == 2
                and isinstance(step[0], str)
                and not isinstance(step[1], str)
            )
            if is_label_first:
                label, image = step
                items.append((image, label))
            else:
                # Already a bare image or an (image, caption) pair.
                items.append(step)
        return items

    def handle_sketch_input(sketch_np_array):
        if sketch_np_array is None:
            return [], "Please draw something."
        steps, text = process_image_for_gradio(sketch_np_array, model_for_app, class_names_for_app)
        return to_gallery_items(steps), text

    def handle_upload_input(uploaded_pil_image):
        if uploaded_pil_image is None:
            return [], "Please upload an image."
        steps, text = process_image_for_gradio(np.array(uploaded_pil_image), model_for_app, class_names_for_app)
        return to_gallery_items(steps), text

    with gr.Blocks(title="Handwritten Character Recognition App") as app_interface:
        gr.Markdown("## Handwritten Character Recognition App")
        gr.Markdown(f"Using Model: **{selected_model_name}** with **{app_num_classes}** classes. Ensure this matches your intent.")
        if model_for_app is None:
            gr.Markdown("**WARNING: MODEL NOT LOADED. PREDICTIONS WILL FAIL.**")

        with gr.Tab("Draw Text"):
            # NOTE(review): `invert_colors` and `shape` look like Gradio 3.x
            # arguments - confirm they exist in the installed Gradio version.
            sketch_input = gr.Sketchpad(label="Draw Here", type="numpy", image_mode="RGB", invert_colors=False, shape=(700,250))
            sketch_button = gr.Button("Recognize Drawing", variant="primary")

        with gr.Tab("Upload Image"):
            upload_input = gr.Image(label="Upload Image", type="pil", image_mode="RGB")
            upload_button = gr.Button("Recognize Uploaded Image", variant="primary")

        recognized_text_output = gr.Textbox(label="Recognized Text")
        gallery_output = gr.Gallery(label="Processing Steps & Results", columns=[3], object_fit="contain", height="auto")

        sketch_button.click(fn=handle_sketch_input, inputs=sketch_input, outputs=[gallery_output, recognized_text_output])
        upload_button.click(fn=handle_upload_input, inputs=upload_input, outputs=[gallery_output, recognized_text_output])

        # Adjacent string literals are concatenated; each line break is an explicit "\n".
        gr.Markdown(
            "### Usage Notes:\n"
            "- **Drawing:** Use your mouse to draw. For multiple characters, leave some space.\n"
            "- **Uploading:** Black text on a white background is preferred.\n"
            "- **Segmentation:** The quality of character segmentation can vary. Clear, well-spaced handwriting works best."
        )
    # share=True requests a public share link; debug=True keeps the cell
    # running and surfaces errors in the notebook output.
    app_interface.launch(share=True, debug=True)

# Start the app only when both a model and a non-empty label list are available;
# otherwise explain what is missing.
if not (app_model and app_class_labels):
    print("ERROR: Gradio app cannot be launched because the model or class labels could not be loaded.")
    print("Please check model paths and data root path for labels.")
else:
    launch_gradio_app(app_model, app_class_labels)