In [None]:
# --- Cell 1: Import Libraries and Set Up Environment ---
"""
# Handwritten Character Recognition: Inference Notebook

This notebook demonstrates how to use trained handwritten character recognition models for inference.
It covers loading models, performing single and batch inference, analyzing results, and visualizing predictions.
"""

import os
import time
import random
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import cv2
from PIL import Image
from glob import glob
import json

# PyTorch imports
import torch
import torch.nn.functional as F
from torchvision import transforms

# Import utility modules
from src.models_util import get_model, get_model_info
from src.inference_utils import prepare_single_image, predict_single_image, predict_batch_images
from src.inference_utils import extract_characters_from_image, create_visualization_grid
from src.inference_utils import get_top_k_predictions, benchmark_inference_speed

# Seed every random number generator used by the notebook so runs are repeatable
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True

# Pick the compute device: Apple MPS first, then CUDA, otherwise fall back to the CPU
device = torch.device(
    "mps" if torch.backends.mps.is_available() and torch.backends.mps.is_built()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print(f"Using device: {device}")

# Make sure the output folder of every inference mode exists
for _results_dir in (
    "inference_results",
    "inference_results/single",
    "inference_results/batch",
    "inference_results/text",
):
    os.makedirs(_results_dir, exist_ok=True)



In [None]:
# --- Cell 2: Model Loading Functions ---
"""
## Model Loading Functions

These functions handle loading trained models from checkpoints.
They include options for different model architectures and configurations.
"""

def load_model_checkpoint(checkpoint_path, model_name, num_classes, device=device):
    """
    Build a model architecture and restore its weights from a checkpoint file.

    Args:
        checkpoint_path: Location of the checkpoint file on disk
        model_name: Architecture name handed to ``get_model``
        num_classes: Number of output classes of the model
        device: Device passed to ``get_model`` and used as ``map_location``

    Returns:
        torch.nn.Module: The model in evaluation mode, or None when the
        checkpoint file is missing or anything fails while loading it
    """
    print(f"Loading model checkpoint from: {checkpoint_path}")

    if not os.path.exists(checkpoint_path):
        print(f"ERROR: Checkpoint file '{checkpoint_path}' not found")
        return None

    try:
        # Fresh, untrained architecture; the weights come from the checkpoint
        net = get_model(model_name, num_classes, device, pretrained=False)

        # NOTE(review): torch.load deserializes the file; only load checkpoints
        # from a trusted source (consider weights_only=True) -- confirm.
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # The checkpoint either wraps the weights under 'model_state_dict'
        # or is the bare state dict itself
        weights = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
        net.load_state_dict(weights)
        net.eval()

        param_count = sum(p.numel() for p in net.parameters())
        print(f"Model loaded successfully with {param_count} parameters")

        # Report whichever training metadata the checkpoint carries
        metadata = {
            key: checkpoint[key]
            for key in ('epoch', 'accuracy', 'val_acc', 'loss')
            if key in checkpoint
        }
        if metadata:
            print("Checkpoint metadata:")
            for key, value in metadata.items():
                print(f"  {key}: {value}")

        return net

    except Exception as e:
        print(f"Error loading model checkpoint: {e}")
        import traceback
        traceback.print_exc()
        return None

def load_class_names(class_names_path):
    """
    Load class names from a text file holding one class name per line.

    Surrounding whitespace is stripped from every line and blank lines are
    skipped. The file is decoded as UTF-8 (a leading byte-order mark is
    tolerated) so the result does not depend on the platform's default
    encoding and a BOM never ends up inside the first class name.

    Args:
        class_names_path: Path to the class names file

    Returns:
        list: Class names in file order, or None when the file is missing
        or cannot be read/decoded
    """
    print(f"Loading class names from: {class_names_path}")

    if not os.path.exists(class_names_path):
        print(f"ERROR: Class names file '{class_names_path}' not found")
        return None

    try:
        with open(class_names_path, 'r', encoding='utf-8-sig') as f:
            class_names = [name for name in (line.strip() for line in f) if name]
    except (OSError, ValueError) as e:
        # OSError: unreadable path; ValueError covers UnicodeDecodeError
        print(f"Error loading class names: {e}")
        return None

    print(f"Loaded {len(class_names)} class names")
    return class_names

def find_available_models(model_checkpoints_dir="model_checkpoints"):
    """
    Find available trained models in the checkpoints directory.

    Every sub-directory that contains at least one ``*.pth`` or ``*.pt`` file
    counts as a model. Sub-directories are visited in sorted order so the
    returned dictionary has a deterministic key order, and directory names are
    escaped before globbing so names containing ``[``, ``]``, ``*`` or ``?``
    are still found.

    Args:
        model_checkpoints_dir: Directory containing model checkpoints

    Returns:
        dict: Maps each model directory name to a dict with the keys
        'checkpoint_files' (sorted list of paths), 'class_names_path',
        'model_info_path' (each a path or None), 'has_class_names' and
        'has_model_info'. Empty when the directory is missing or holds no
        models.
    """
    # Only the glob() function is imported at file level, so fetch escape here
    from glob import escape

    print(f"Searching for available models in: {model_checkpoints_dir}")

    if not os.path.exists(model_checkpoints_dir):
        print(f"ERROR: Model checkpoints directory '{model_checkpoints_dir}' not found")
        return {}

    available_models = {}

    # Each sub-directory is expected to be one model type
    model_dirs = sorted(
        entry for entry in os.listdir(model_checkpoints_dir)
        if os.path.isdir(os.path.join(model_checkpoints_dir, entry))
    )

    for model_dir in model_dirs:
        model_path = os.path.join(model_checkpoints_dir, model_dir)

        # Escape the directory part so only the '*.ext' wildcard is treated as a pattern
        pattern_root = escape(model_path)
        checkpoint_files = sorted(
            path
            for pattern in ('*.pth', '*.pt')
            for path in glob(os.path.join(pattern_root, pattern))
        )
        if not checkpoint_files:
            continue

        class_names_path = os.path.join(model_path, 'class_names.txt')
        has_class_names = os.path.exists(class_names_path)

        model_info_path = os.path.join(model_path, 'model_info.txt')
        has_model_info = os.path.exists(model_info_path)

        available_models[model_dir] = {
            'checkpoint_files': checkpoint_files,
            'class_names_path': class_names_path if has_class_names else None,
            'model_info_path': model_info_path if has_model_info else None,
            'has_class_names': has_class_names,
            'has_model_info': has_model_info,
        }

    if available_models:
        print(f"Found {len(available_models)} available models:")
        for model_name, info in available_models.items():
            print(f"  - {model_name}:")
            print(f"    - Checkpoints: {len(info['checkpoint_files'])}")
            print(f"    - Has class names: {info['has_class_names']}")
            print(f"    - Has model info: {info['has_model_info']}")
    else:
        print("No trained models found.")

    return available_models

def load_inference_model(model_name, checkpoint_type="best", model_checkpoints_dir="model_checkpoints"):
    """
    Locate a trained model on disk and load it, with its class names, for inference.

    Args:
        model_name: Name of the model (its sub-directory inside the checkpoints directory)
        checkpoint_type: 'best' or 'final' to pick a checkpoint whose file name
            contains that word (falling back to the first checkpoint found),
            or an existing path to a specific checkpoint file
        model_checkpoints_dir: Directory containing model checkpoints

    Returns:
        tuple: (model, class_names), or (None, None) when the model, its class
        names or the checkpoint cannot be resolved or loaded
    """
    print(f"Loading model '{model_name}' for inference...")

    # Discover what is on disk and make sure the requested model is there
    available_models = find_available_models(model_checkpoints_dir)
    model_info = available_models.get(model_name)
    if model_info is None:
        print(f"ERROR: Model '{model_name}' not found in available models")
        return None, None

    # Class names: read them from the model folder, or fall back to 0-9, A-Z, a-z
    if model_info['has_class_names']:
        class_names = load_class_names(model_info['class_names_path'])
    else:
        print(f"WARNING: No class_names.txt found for model '{model_name}'")
        print("Will use default class names (0-9, A-Z, a-z)")
        digits = [str(d) for d in range(10)]
        uppercase = [chr(code) for code in range(ord('A'), ord('Z') + 1)]
        lowercase = [chr(code) for code in range(ord('a'), ord('z') + 1)]
        class_names = digits + uppercase + lowercase

    if class_names is None:
        print("ERROR: Failed to load class names")
        return None, None

    num_classes = len(class_names)
    candidates = model_info['checkpoint_files']

    # Resolve which checkpoint file to load
    if checkpoint_type in ("best", "final"):
        # Prefer a file whose name mentions the requested kind (e.g. best_model.pth)
        tagged = [cp for cp in candidates if checkpoint_type in os.path.basename(cp).lower()]
        if tagged:
            checkpoint_path = tagged[0]
        else:
            print(f"WARNING: No '{checkpoint_type}' checkpoint found, falling back to first available checkpoint")
            checkpoint_path = candidates[0]
    elif os.path.exists(checkpoint_type):
        # Anything else is treated as an explicit path to a checkpoint
        checkpoint_path = checkpoint_type
    else:
        print(f"ERROR: Invalid checkpoint type '{checkpoint_type}'")
        return None, None

    model = load_model_checkpoint(checkpoint_path, model_name, num_classes, device)
    if model is None:
        print("ERROR: Failed to load model")
        return None, None

    print(f"Successfully loaded model '{model_name}' with {num_classes} classes")
    return model, class_names



In [None]:
# --- Cell 3: Single Image Inference ---
"""
## Single Image Inference

Perform inference on a single image and visualize the results.
This section includes functions for loading images, preprocessing them, and getting predictions.
"""

def infer_single_image(model, class_names, image_path, save_dir="inference_results/single"):
    """
    Perform inference on a single image and visualize the results.

    The figure shows the input image next to a bar chart of the five most
    probable classes and is saved as ``prediction_<image name>`` in save_dir.
    When the input's extension is not a format the figure can be written in
    (for example .bmp or .gif) the figure is saved as PNG instead.

    Args:
        model: Trained model
        class_names: List of class names
        image_path: Path to the input image
        save_dir: Directory to save the results

    Returns:
        dict: The result of ``predict_single_image`` (this function reads its
        'class_probabilities', 'predicted_class' and 'confidence' entries),
        or None when the image is missing or inference fails
    """
    print(f"Running inference on image: {image_path}")

    if not os.path.exists(image_path):
        print(f"ERROR: Image file '{image_path}' not found")
        return None

    fig = None
    try:
        # Predict using the utility function
        result = predict_single_image(
            model=model,
            image_path=image_path,
            class_names=class_names,
            device=device,
            normalization_type='imagenet',
            return_probabilities=True
        )

        fig = plt.figure(figsize=(12, 6))

        # Left panel: the input image. imshow copies the pixels, so the file can
        # be closed right away; cmap only affects single-channel images.
        plt.subplot(1, 2, 1)
        with Image.open(image_path) as img:
            plt.imshow(img, cmap='gray')
        plt.title("Input Image")
        plt.axis('off')

        # Right panel: the top-5 class probabilities
        plt.subplot(1, 2, 2)
        top_k = 5
        probs = result['class_probabilities']
        sorted_probs = sorted(probs.items(), key=lambda x: x[1], reverse=True)[:top_k]

        classes = [cls for cls, _ in sorted_probs]
        values = [val for _, val in sorted_probs]

        plt.barh(range(len(classes)), values, color='skyblue')
        plt.yticks(range(len(classes)), classes)
        plt.xlim(0, 1)
        plt.xlabel("Confidence")
        plt.title(f"Prediction: {result['predicted_class']} ({result['confidence']:.2%})")
        plt.grid(axis='x', linestyle='--', alpha=0.7)

        # Add values on the bars
        for i, v in enumerate(values):
            plt.text(v + 0.01, i, f"{v:.2%}", va='center')

        plt.tight_layout()

        # Save the visualization. savefig infers the output format from the file
        # extension, so swap extensions it cannot write for .png.
        os.makedirs(save_dir, exist_ok=True)
        basename = os.path.basename(image_path)
        stem, ext = os.path.splitext(basename)
        if ext and ext[1:].lower() not in fig.canvas.get_supported_filetypes():
            basename = f"{stem}.png"
        save_path = os.path.join(save_dir, f"prediction_{basename}")
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

        print(f"Prediction saved to {save_path}")
        plt.show()

        # Print the prediction
        print(f"Prediction: {result['predicted_class']} with confidence {result['confidence']:.2%}")

        return result

    except Exception as e:
        print(f"Error during inference: {e}")
        import traceback
        traceback.print_exc()
        return None

    finally:
        # Release the figure so repeated calls do not accumulate open figures
        if fig is not None:
            plt.close(fig)

def visualize_activation_map(model, image_path, class_names, save_dir="inference_results/single"):
    """
    Visualize a Grad-CAM style activation map for a single image.

    The last ``torch.nn.Conv2d`` found directly inside ``model.features`` is
    used as the target layer. Its output and the gradient of the predicted
    class score with respect to that output are captured during a single
    forward/backward pass, combined into a class activation map, and drawn
    next to the grayscale input and an overlay of both. The figure is saved as
    ``activation_<image name>`` in save_dir (as PNG when the input's extension
    is not a writable figure format).

    Args:
        model: Trained model exposing an iterable ``features`` attribute
        image_path: Path to the input image
        class_names: List of class names
        save_dir: Directory to save the results

    Returns:
        dict: 'predicted_class', 'confidence', 'cam' (array resized to the
        image size) and 'superimposed' (RGB overlay), or None when the image is
        missing, the model has no usable convolutional layer, or an error
        occurs
    """
    print(f"Generating activation map for image: {image_path}")

    if not os.path.exists(image_path):
        print(f"ERROR: Image file '{image_path}' not found")
        return None

    fig = None
    try:
        # Grayscale copy for display (the file is closed once it is converted)
        # and the preprocessed tensor for the model
        with Image.open(image_path) as source_img:
            img = source_img.convert('L')
        img_tensor = prepare_single_image(image_path, normalization_type='imagenet', device=device)

        # Check if the model has features attribute (necessary for grad-CAM)
        if not hasattr(model, 'features'):
            print("WARNING: Model does not have 'features' attribute, cannot generate activation map")
            return None

        # Use the last convolutional layer that sits directly in model.features
        target_layer = next(
            (m for m in reversed(list(model.features)) if isinstance(m, torch.nn.Conv2d)),
            None,
        )
        if target_layer is None:
            print("WARNING: Could not find a convolutional layer in the model")
            return None

        captured = {}

        def capture_activations(module, inputs, output):
            captured['activations'] = output.detach()

            def capture_gradients(grad):
                captured['gradients'] = grad.detach()

            # A tensor hook on the layer output replaces the deprecated
            # Module.register_backward_hook used previously.
            if output.requires_grad:
                output.register_hook(capture_gradients)

        hook_handle = target_layer.register_forward_hook(capture_activations)
        model.eval()
        try:
            # One forward pass provides both the prediction and the graph for
            # the backward pass; make sure autograd is on even if the caller
            # is inside torch.no_grad().
            with torch.enable_grad():
                model.zero_grad()
                output = model(img_tensor)

                probabilities = F.softmax(output.detach(), dim=1)
                confidence, predicted_idx = torch.max(probabilities, 1)

                # Back-propagate only the score of the predicted class
                one_hot = torch.zeros_like(output)
                one_hot[0, predicted_idx.item()] = 1
                output.backward(gradient=one_hot)
        finally:
            # Always detach the hook so a failure does not leave it on the model
            hook_handle.remove()
            model.zero_grad()

        predicted_class = class_names[predicted_idx.item()]

        feature_maps = captured.get('activations')
        gradients = captured.get('gradients')
        if feature_maps is None or gradients is None:
            print("WARNING: Could not capture activations and gradients for the target layer")
            return None

        # Channel weights: gradients averaged over the spatial dimensions
        weights = torch.mean(gradients, dim=(2, 3))

        # Weighted sum of the feature maps of the first (only) batch element,
        # computed on whatever device the feature maps live on
        cam = (weights[0].view(-1, 1, 1) * feature_maps[0]).sum(dim=0).float()

        # Keep positive evidence only and scale to [0, 1]
        cam = F.relu(cam)
        cam = cam - torch.min(cam)
        cam = cam / (torch.max(cam) + 1e-10)

        # Resize to match the original image
        cam = cam.cpu().numpy()
        cam = cv2.resize(cam, (img.width, img.height))

        img_np = np.array(img)

        # Colour-code the map and blend it with the grayscale input
        heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        img_rgb = cv2.cvtColor(img_np, cv2.COLOR_GRAY2RGB)
        superimposed = cv2.addWeighted(img_rgb, 0.6, heatmap, 0.4, 0)

        # Visualize
        fig = plt.figure(figsize=(15, 5))

        plt.subplot(1, 3, 1)
        plt.imshow(img, cmap='gray')
        plt.title("Original Image")
        plt.axis('off')

        plt.subplot(1, 3, 2)
        plt.imshow(cam, cmap='jet')
        plt.title("Activation Map")
        plt.axis('off')

        plt.subplot(1, 3, 3)
        plt.imshow(superimposed)
        plt.title(f"Prediction: {predicted_class} ({confidence.item():.2%})")
        plt.axis('off')

        plt.tight_layout()

        # Save the visualization. savefig infers the output format from the file
        # extension, so swap extensions it cannot write for .png.
        os.makedirs(save_dir, exist_ok=True)
        basename = os.path.basename(image_path)
        stem, ext = os.path.splitext(basename)
        if ext and ext[1:].lower() not in fig.canvas.get_supported_filetypes():
            basename = f"{stem}.png"
        save_path = os.path.join(save_dir, f"activation_{basename}")
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

        print(f"Activation map saved to {save_path}")
        plt.show()

        return {
            'predicted_class': predicted_class,
            'confidence': confidence.item(),
            'cam': cam,
            'superimposed': superimposed
        }

    except Exception as e:
        print(f"Error generating activation map: {e}")
        import traceback
        traceback.print_exc()
        return None

    finally:
        # Release the figure so repeated calls do not accumulate open figures
        if fig is not None:
            plt.close(fig)



In [None]:
# --- Cell 4: Batch Inference ---
"""
## Batch Inference

Process multiple images in batch mode for efficient inference.
This section includes functions for folder processing and result aggregation.
"""

def _list_batch_image_paths(image_folder, extensions):
    """Return the sorted paths of visible regular files in image_folder whose lower-cased extension is in extensions."""
    if not os.path.isdir(image_folder):
        return []

    image_paths = []
    for name in sorted(os.listdir(image_folder)):
        # Hidden entries (e.g. "._photo.jpg" sidecar files) were never matched
        # by the former "*.ext" globs either
        if name.startswith('.'):
            continue
        if os.path.splitext(name)[1].lower() not in extensions:
            continue
        path = os.path.join(image_folder, name)
        if os.path.isfile(path):
            image_paths.append(path)
    return image_paths


def _summarize_batch_predictions(results, total_images):
    """Build the summary dict (image count, per-class counts sorted descending, confidence min/max/avg)."""
    from collections import Counter

    confidences = [r['confidence'] for r in results]
    class_counts = Counter(r['predicted_class'] for r in results)
    return {
        'total_images': total_images,
        # most_common() sorts by count, descending; ties keep first-seen order
        'predicted_classes': dict(class_counts.most_common()),
        'confidence_stats': {
            'min': min(confidences),
            'max': max(confidences),
            'avg': sum(confidences) / len(confidences)
        }
    }


def _plot_batch_overview(results, summary, save_dir):
    """Draw and save the class distribution, the confidence histogram and up to five random sample predictions."""
    # Class distribution
    fig = plt.figure(figsize=(10, 6))

    classes = list(summary['predicted_classes'].keys())
    counts = list(summary['predicted_classes'].values())

    # Limit to top 20 classes if there are too many
    if len(classes) > 20:
        classes = classes[:20]
        counts = counts[:20]
        plt.title("Top 20 Predicted Classes")
    else:
        plt.title("Predicted Class Distribution")

    plt.bar(classes, counts, color='skyblue')
    plt.xticks(rotation=45, ha='right')
    plt.ylabel("Count")
    plt.xlabel("Class")
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout()

    plt.savefig(os.path.join(save_dir, 'class_distribution.png'), dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)

    # Confidence distribution
    fig = plt.figure(figsize=(10, 6))

    confidences = [r['confidence'] for r in results]
    avg_confidence = summary['confidence_stats']['avg']
    plt.hist(confidences, bins=20, color='skyblue', edgecolor='black')
    plt.axvline(x=avg_confidence, color='red', linestyle='--',
                label=f"Avg: {avg_confidence:.2%}")
    plt.xlabel("Confidence")
    plt.ylabel("Count")
    plt.title("Confidence Distribution")
    plt.grid(linestyle='--', alpha=0.7)
    plt.legend()
    plt.tight_layout()

    plt.savefig(os.path.join(save_dir, 'confidence_distribution.png'), dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)

    # Sample predictions
    sample_count = min(5, len(results))
    sample_indices = random.sample(range(len(results)), sample_count)

    fig = plt.figure(figsize=(15, 3 * sample_count))

    for i, idx in enumerate(sample_indices):
        result = results[idx]

        plt.subplot(sample_count, 2, i*2 + 1)
        # imshow copies the pixels, so the file can be closed immediately;
        # cmap only affects single-channel images
        with Image.open(result['image_path']) as img:
            plt.imshow(img, cmap='gray')
        plt.title(f"Sample {i+1}")
        plt.axis('off')

        plt.subplot(sample_count, 2, i*2 + 2)
        plt.text(0.5, 0.5,
                 f"Predicted: {result['predicted_class']}\nConfidence: {result['confidence']:.2%}",
                 ha='center', va='center', fontsize=12)
        plt.axis('off')

    plt.tight_layout()

    plt.savefig(os.path.join(save_dir, 'sample_predictions.png'), dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)


def infer_batch_images(model, class_names, image_folder, save_dir="inference_results/batch", 
                     batch_size=16, visualize=True):
    """
    Perform inference on multiple images in a folder.

    Image files are matched by extension regardless of letter case, each file
    is processed exactly once, and the files are handled in sorted order. The
    predictions and a summary are written to ``batch_results.json`` in
    save_dir; with visualize=True three overview figures are saved there too.

    Args:
        model: Trained model
        class_names: List of class names
        image_folder: Path to the folder containing images
        save_dir: Directory to save the results
        batch_size: Batch size for processing
        visualize: Whether to visualize the results

    Returns:
        list: Prediction results for all images as returned by
        ``predict_batch_images`` (this function reads the 'confidence',
        'predicted_class' and 'image_path' entries of each result), or None
        when the folder is missing, holds no images, no predictions come back,
        or an error occurs
    """
    print(f"Running batch inference on images in: {image_folder}")

    if not os.path.exists(image_folder):
        print(f"ERROR: Image folder '{image_folder}' not found")
        return None

    # Find all image files
    image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.gif']
    image_paths = _list_batch_image_paths(image_folder, image_extensions)

    if not image_paths:
        print(f"ERROR: No image files found in '{image_folder}'")
        return None

    print(f"Found {len(image_paths)} images for batch inference")

    try:
        # Predict using the utility function
        results = predict_batch_images(
            model=model,
            image_paths=image_paths,
            class_names=class_names,
            device=device,
            normalization_type='imagenet',
            batch_size=batch_size
        )

        # Without predictions there is nothing to summarize (min/max would fail)
        if not results:
            print("ERROR: Batch inference returned no predictions")
            return None

        summary = _summarize_batch_predictions(results, len(image_paths))

        # Save results to JSON
        # NOTE(review): assumes the entries returned by predict_batch_images are
        # JSON-serializable (plain floats/strings) -- confirm against the utility.
        os.makedirs(save_dir, exist_ok=True)
        results_file = os.path.join(save_dir, 'batch_results.json')

        with open(results_file, 'w') as f:
            json.dump({
                'summary': summary,
                'predictions': results
            }, f, indent=2)

        print(f"Batch results saved to {results_file}")

        # Visualize results if requested
        if visualize:
            _plot_batch_overview(results, summary, save_dir)

        return results

    except Exception as e:
        print(f"Error during batch inference: {e}")
        import traceback
        traceback.print_exc()
        return None

def analyze_batch_results(results_file, save_dir="inference_results/batch"):
    """
    Analyze the results of batch inference.

    Reads the JSON file (keys 'summary' and 'predictions'), groups the
    predictions by confidence (>= 0.9 is high, <= 0.5 is low) and by predicted
    class, and saves two figures to save_dir: the average confidence per class
    and up to three sample images each for high and low confidence. save_dir is
    created when it does not exist yet.

    Args:
        results_file: Path to the JSON file containing batch results
        save_dir: Directory to save the analysis results

    Returns:
        dict: Analysis with the keys 'total_images', 'class_distribution',
        'confidence_stats', 'high_confidence', 'low_confidence',
        'confidence_by_class' and 'avg_confidence_by_class' (sorted by
        descending average), or None when the file is missing or an error
        occurs
    """
    print(f"Analyzing batch inference results from: {results_file}")

    if not os.path.exists(results_file):
        print(f"ERROR: Results file '{results_file}' not found")
        return None

    try:
        # Load results
        with open(results_file, 'r') as f:
            data = json.load(f)

        summary = data['summary']
        predictions = data['predictions']

        print(f"Loaded results for {summary['total_images']} images")

        # Additional analysis
        analysis = {
            'total_images': summary['total_images'],
            'class_distribution': summary['predicted_classes'],
            'confidence_stats': summary['confidence_stats'],
            'high_confidence': [],
            'low_confidence': [],
            'confidence_by_class': {}
        }

        # Find high and low confidence predictions
        confidence_threshold_high = 0.9
        confidence_threshold_low = 0.5

        for pred in predictions:
            confidence = pred['confidence']
            if confidence >= confidence_threshold_high:
                analysis['high_confidence'].append(pred)
            if confidence <= confidence_threshold_low:
                analysis['low_confidence'].append(pred)

            # Confidence by class
            analysis['confidence_by_class'].setdefault(pred['predicted_class'], []).append(confidence)

        # Average confidence by class, sorted from most to least confident
        avg_by_class = {
            cls: sum(confs) / len(confs)
            for cls, confs in analysis['confidence_by_class'].items()
        }
        analysis['avg_confidence_by_class'] = dict(sorted(
            avg_by_class.items(),
            key=lambda x: x[1],
            reverse=True
        ))

        # The figures below are written into save_dir, so it has to exist
        os.makedirs(save_dir, exist_ok=True)

        # Create visualizations
        # 1. Average confidence by class
        fig = plt.figure(figsize=(12, 6))

        classes = list(analysis['avg_confidence_by_class'].keys())
        avg_confs = list(analysis['avg_confidence_by_class'].values())

        # Limit to top 20 and bottom 10 classes if there are too many
        if len(classes) > 30:
            plt.subplot(1, 2, 1)
            plt.bar(classes[:20], avg_confs[:20], color='green')
            plt.xticks(rotation=45, ha='right')
            plt.ylabel("Average Confidence")
            plt.title("Top 20 Classes by Confidence")
            plt.ylim(0, 1)
            plt.grid(axis='y', linestyle='--', alpha=0.7)

            plt.subplot(1, 2, 2)
            plt.bar(classes[-10:], avg_confs[-10:], color='red')
            plt.xticks(rotation=45, ha='right')
            plt.ylabel("Average Confidence")
            plt.title("Bottom 10 Classes by Confidence")
            plt.ylim(0, 1)
            plt.grid(axis='y', linestyle='--', alpha=0.7)
        else:
            plt.bar(classes, avg_confs, color='skyblue')
            plt.xticks(rotation=45, ha='right')
            plt.ylabel("Average Confidence")
            plt.title("Average Confidence by Class")
            plt.ylim(0, 1)
            plt.grid(axis='y', linestyle='--', alpha=0.7)

        plt.tight_layout()

        # Save the visualization
        avg_conf_file = os.path.join(save_dir, 'avg_confidence_by_class.png')
        plt.savefig(avg_conf_file, dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(fig)

        # 2. Sample high and low confidence predictions: the first three of
        #    each group, high-confidence rows first
        sample_rows = [
            (f"{label} Confidence Sample {i + 1}", pred)
            for label, group in (("High", analysis['high_confidence']),
                                 ("Low", analysis['low_confidence']))
            for i, pred in enumerate(group[:3])
        ]

        if sample_rows:
            row_count = len(sample_rows)
            fig = plt.figure(figsize=(15, 3 * row_count))

            for row, (title, pred) in enumerate(sample_rows):
                plt.subplot(row_count, 2, row*2 + 1)
                # imshow copies the pixels, so the file can be closed right away;
                # cmap only affects single-channel images
                with Image.open(pred['image_path']) as img:
                    plt.imshow(img, cmap='gray')
                plt.title(title)
                plt.axis('off')

                plt.subplot(row_count, 2, row*2 + 2)
                plt.text(0.5, 0.5,
                         f"Predicted: {pred['predicted_class']}\nConfidence: {pred['confidence']:.2%}",
                         ha='center', va='center', fontsize=12)
                plt.axis('off')

            plt.tight_layout()

            # Save the visualization
            conf_samples_file = os.path.join(save_dir, 'confidence_samples.png')
            plt.savefig(conf_samples_file, dpi=300, bbox_inches='tight')
            plt.show()
            plt.close(fig)

        return analysis

    except Exception as e:
        print(f"Error analyzing batch results: {e}")
        import traceback
        traceback.print_exc()
        return None



In [None]:
# --- Cell 5: Handwritten Text Recognition ---
"""
## Handwritten Text Recognition

Recognize characters in handwritten text images.
This section includes functions for character segmentation and recognition.
"""

def recognize_handwritten_text(model, class_names, image_path, save_dir="inference_results/text"):
    """
    Recognize characters in a handwritten text image.

    Segmentation and classification are delegated to
    ``extract_characters_from_image``; when that helper returns visualization
    images they are laid out with ``create_visualization_grid`` and shown.
    Inference runs on the module-level ``device``.

    Args:
        model: Trained model
        class_names: List of class names
        image_path: Path to the handwritten text image
        save_dir: Directory to save the results. Per-image artifacts go into a
            sub-directory named after the image file (without its extension).

    Returns:
        tuple: (detected_text, character_info, visualization_images), or
        (None, None, None) when the image is missing or any step raises.
    """
    print(f"Recognizing handwritten text in image: {image_path}")

    if not os.path.exists(image_path):
        print(f"ERROR: Image file '{image_path}' not found")
        return None, None, None

    try:
        image_name = os.path.basename(image_path)
        # splitext strips only the final extension, so "page.v2.png" keeps the
        # stem "page.v2" instead of collapsing to "page" and sharing an output
        # directory with every other "page.*" image.
        image_stem = os.path.splitext(image_name)[0]

        # The grid below is saved directly into save_dir, so make sure it exists.
        os.makedirs(save_dir, exist_ok=True)

        # Extract and classify characters using the utility function.
        # NOTE(review): spacing/min_area are passed through unchanged; their
        # exact meaning is defined by the helper - confirm before tuning.
        detected_text, character_info, visualization_images = extract_characters_from_image(
            image_path=image_path,
            model=model,
            class_names=class_names,
            device=device,
            spacing=24,
            min_area=50,
            normalization_type='imagenet',
            visualize=True,
            output_dir=os.path.join(save_dir, image_stem)
        )

        # Create a complete visualization
        if visualization_images:
            # Two columns; figure height grows with the number of rows needed
            grid_rows = (len(visualization_images) + 1) // 2
            create_visualization_grid(
                visualization_images,
                grid_cols=2,
                figsize=(15, 5 * grid_rows),
                save_path=os.path.join(save_dir, f"recognition_{image_name}")
            )

            plt.show()

        print(f"Recognized text: {detected_text}")

        return detected_text, character_info, visualization_images

    except Exception as e:
        # Notebook-friendly best effort: report the failure and keep the cell alive
        print(f"Error recognizing handwritten text: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None

def _to_builtin(value):
    """Return ``value.item()`` when available (NumPy/torch scalars), else ``value`` unchanged."""
    return value.item() if hasattr(value, 'item') else value


def recognize_multiple_text_images(model, class_names, image_folder, save_dir="inference_results/text"):
    """
    Recognize characters in multiple handwritten text images.

    Every regular, non-hidden file in ``image_folder`` whose extension matches a
    known image type (case-insensitively) is passed to
    ``recognize_handwritten_text``. Images for which recognition fails are
    left out of the result. A JSON summary is written to
    ``<save_dir>/text_recognition_results.json``.

    Args:
        model: Trained model
        class_names: List of class names
        image_folder: Path to the folder containing text images
        save_dir: Directory to save the results

    Returns:
        dict: Recognition results keyed by image path, each with 'text',
        'characters' (count) and 'character_info'; None when the folder is
        missing or contains no images.
    """
    print(f"Recognizing handwritten text in images from: {image_folder}")

    if not os.path.exists(image_folder):
        print(f"ERROR: Image folder '{image_folder}' not found")
        return None

    # Find all image files. A single case-insensitive suffix test replaces the
    # paired "*.ext" / "*.EXT" globs, which produced duplicates on platforms
    # with case-insensitive matching and missed mixed-case names like ".Jpg".
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.gif')
    entries = os.listdir(image_folder) if os.path.isdir(image_folder) else []
    candidates = (
        os.path.join(image_folder, entry)
        for entry in entries
        # Hidden files were never matched by the "*" glob either
        if not entry.startswith('.') and entry.lower().endswith(image_extensions)
    )
    # Sorted so that processing order and the JSON output are reproducible
    image_paths = sorted(path for path in candidates if os.path.isfile(path))

    if not image_paths:
        print(f"ERROR: No image files found in '{image_folder}'")
        return None

    print(f"Found {len(image_paths)} images for text recognition")

    results = {}

    for img_path in image_paths:
        print(f"Processing: {os.path.basename(img_path)}")

        detected_text, character_info, _ = recognize_handwritten_text(
            model=model,
            class_names=class_names,
            image_path=img_path,
            save_dir=save_dir
        )

        if detected_text is not None:
            results[img_path] = {
                'text': detected_text,
                'characters': len(character_info or []),
                'character_info': character_info
            }

    # Build the JSON payload before opening the file so that a serialization
    # problem cannot leave a truncated results file behind.
    serializable_results = {}
    for path, info in results.items():
        char_info = [
            {
                'character': char_data['character'],
                # NumPy/torch scalars are not JSON serializable; unwrap them
                'confidence': _to_builtin(char_data['confidence']),
                'bbox': [_to_builtin(coord) for coord in char_data['bbox']],
                'index': _to_builtin(char_data['index'])
            }
            for char_data in (info['character_info'] or [])
        ]

        serializable_results[path] = {
            'text': info['text'],
            'characters': info['characters'],
            'character_info': char_info
        }

    # Save results to JSON
    os.makedirs(save_dir, exist_ok=True)
    results_file = os.path.join(save_dir, 'text_recognition_results.json')

    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(serializable_results, f, indent=2)

    print(f"Text recognition results saved to {results_file}")

    return results



In [None]:
# --- Cell 6: Model Analysis and Benchmarking ---
"""
## Model Analysis and Benchmarking

Analyze model performance and benchmark inference speed.
This section includes functions for visualizing model confidence and measuring inference time.
"""

def visualize_model_confidence(model, class_names, image_path, save_dir="inference_results/single"):
    """
    Visualize model confidence across all classes for a single image.

    Draws the input image next to a horizontal bar chart of the highest and
    lowest scoring classes, saves the figure into ``save_dir`` and shows it.
    The model is switched to eval mode as a side effect, and inference runs on
    the module-level ``device``.

    Args:
        model: Trained model
        class_names: List of class names
        image_path: Path to the input image
        save_dir: Directory to save the results

    Returns:
        dict: Confidence scores for all classes (class -> confidence, ordered
        from most to least confident), or None when the image is missing or
        any step raises.
    """
    print(f"Visualizing model confidence for image: {image_path}")

    if not os.path.exists(image_path):
        print(f"ERROR: Image file '{image_path}' not found")
        return None

    fig = None
    try:
        # Load and preprocess the image
        img_tensor = prepare_single_image(image_path, normalization_type='imagenet', device=device)

        # Ask for every class so the full distribution is available
        model.eval()
        top_predictions = get_top_k_predictions(
            model=model,
            image_tensor=img_tensor,
            class_names=class_names,
            k=len(class_names),
            device=device
        )

        # Sort predictions by confidence (stable, so ties keep their input order)
        sorted_predictions = sorted(top_predictions, key=lambda x: x['confidence'], reverse=True)

        # Show the top 10 and bottom 5 for clarity. With 15 or fewer classes
        # those two slices overlap, so simply show everything; this keeps the
        # order deterministic and does not require hashable prediction values.
        if len(sorted_predictions) > 15:
            display_preds = sorted_predictions[:10] + sorted_predictions[-5:]
        else:
            display_preds = list(sorted_predictions)

        # Extract classes and confidences
        classes = [pred['class'] for pred in display_preds]
        confidences = [pred['confidence'] for pred in display_preds]

        # Green for the top 3, blue for the rest of the top 10, red for the tail
        colors = ['green' if i < 3 else 'skyblue' if i < 10 else 'red' for i in range(len(display_preds))]

        fig = plt.figure(figsize=(15, 8))

        # Original image; the context manager releases the file handle once
        # imshow has taken the pixel data
        plt.subplot(1, 2, 1)
        with Image.open(image_path) as img:
            plt.imshow(img)
        plt.title("Input Image")
        plt.axis('off')

        # Confidence distribution
        plt.subplot(1, 2, 2)
        plt.barh(range(len(classes)), confidences, color=colors)
        plt.yticks(range(len(classes)), classes)
        plt.xlim(0, 1)
        plt.xlabel("Confidence")
        plt.title("Model Confidence Distribution")
        plt.grid(axis='x', linestyle='--', alpha=0.7)

        # Add values on the bars
        for i, v in enumerate(confidences):
            plt.text(v + 0.01, i, f"{v:.2%}", va='center')

        plt.tight_layout()

        # Save the visualization
        os.makedirs(save_dir, exist_ok=True)
        basename = os.path.basename(image_path)
        save_path = os.path.join(save_dir, f"confidence_{basename}")

        # savefig infers the output format from the extension and raises for
        # formats it cannot write (e.g. .bmp, .gif); fall back to PNG then.
        extension = os.path.splitext(basename)[1].lstrip('.').lower()
        if extension not in fig.canvas.get_supported_filetypes():
            save_path += ".png"

        plt.savefig(save_path, dpi=300, bbox_inches='tight')

        print(f"Confidence visualization saved to {save_path}")
        plt.show()

        # Return all confidence scores
        return {pred['class']: pred['confidence'] for pred in sorted_predictions}

    except Exception as e:
        # Notebook-friendly best effort: report the failure and keep the cell alive
        print(f"Error visualizing model confidence: {e}")
        import traceback
        traceback.print_exc()
        return None

    finally:
        # Release the figure so repeated calls do not accumulate open figures
        if fig is not None:
            plt.close(fig)

def benchmark_model_performance(model, image_path, class_names, device=device,
                              num_iterations=100, save_dir="inference_results"):
    """
    Measure inference speed of a model on one sample image and chart it.

    Timing is delegated to ``benchmark_inference_speed`` (with 10 warm-up
    iterations); the average latency in milliseconds and the FPS are printed,
    drawn as a two-bar chart and saved to ``<save_dir>/benchmark_<device>.png``.

    Args:
        model: Trained model
        image_path: Path to a sample image for benchmarking
        class_names: List of class names
        device: Device to run the benchmark on
        num_iterations: Number of inference iterations
        save_dir: Directory to save the results

    Returns:
        dict: Benchmark results as returned by the timing helper, or None when
        the sample image is missing or any step raises.
    """
    print(f"Benchmarking model performance on device: {device}")

    sample_exists = os.path.exists(image_path)
    if not sample_exists:
        print(f"ERROR: Sample image '{image_path}' not found")
        return None

    try:
        # Time the model through the shared utility
        timing = benchmark_inference_speed(
            model=model,
            sample_image_path=image_path,
            class_names=class_names,
            device=device,
            num_iterations=num_iterations,
            warmup_iterations=10
        )

        # Report the numbers
        print("Benchmark results:")
        print(f"  Total time: {timing['total_time']:.4f} seconds")
        latency_ms = timing['avg_time_per_inference'] * 1000
        print(f"  Average time per inference: {latency_ms:.2f} ms")
        throughput = timing['fps']
        print(f"  Frames per second (FPS): {throughput:.2f}")

        # Two-bar chart: latency next to throughput
        plt.figure(figsize=(10, 6))

        bar_labels = ['Time per inference (ms)', 'FPS']
        bar_values = [latency_ms, throughput]
        plt.bar(bar_labels, bar_values, color=['skyblue', 'green'])

        plt.ylabel("Value")
        plt.title(f"Model Performance on {device} ({num_iterations} iterations)")
        plt.grid(axis='y', linestyle='--', alpha=0.7)

        # Annotate each bar slightly above its top edge
        annotations = (f"{latency_ms:.2f} ms", f"{throughput:.2f} FPS")
        for position, (value, caption) in enumerate(zip(bar_values, annotations)):
            plt.text(position, value + 1, caption, ha='center', va='bottom')

        plt.tight_layout()

        # Persist the chart
        os.makedirs(save_dir, exist_ok=True)
        chart_file = os.path.join(save_dir, f"benchmark_{device}.png")
        plt.savefig(chart_file, dpi=300, bbox_inches='tight')

        print(f"Benchmark visualization saved to {chart_file}")
        plt.show()

        return timing

    except Exception as err:
        print(f"Error benchmarking model: {err}")
        import traceback
        traceback.print_exc()
        return None

def compare_models_performance(models_configs, image_path, save_dir="inference_results",
                               num_iterations=50, warmup_iterations=5):
    """
    Compare the performance of multiple models.

    Each model is timed with ``benchmark_inference_speed`` on the same sample
    image using the module-level ``device``; the average latency and FPS are
    then plotted side by side and saved to ``<save_dir>/model_comparison.png``.

    Args:
        models_configs: List of model configurations
            [{'name': 'model_name', 'model': model_instance, 'class_names': class_names}, ...]
            Entries sharing a 'name' overwrite each other in the result.
        image_path: Path to a sample image for benchmarking
        save_dir: Directory to save the results
        num_iterations: Timed inference iterations per model (default 50)
        warmup_iterations: Untimed warm-up iterations per model (default 5)

    Returns:
        dict: Comparison results keyed by model name; an empty dict when
        ``models_configs`` is empty; None when the sample image is missing or
        any step raises.
    """
    print(f"Comparing performance of {len(models_configs)} models")

    if not os.path.exists(image_path):
        print(f"ERROR: Sample image '{image_path}' not found")
        return None

    # Nothing to benchmark: skip rendering and saving an empty chart
    if not models_configs:
        print("No models provided for comparison.")
        return {}

    fig = None
    try:
        comparison_results = {}

        for config in models_configs:
            model_name = config['name']

            print(f"\nBenchmarking model: {model_name}")

            # Run the benchmark
            results = benchmark_inference_speed(
                model=config['model'],
                sample_image_path=image_path,
                class_names=config['class_names'],
                device=device,
                num_iterations=num_iterations,
                warmup_iterations=warmup_iterations
            )

            comparison_results[model_name] = results

            print(f"  Average time per inference: {results['avg_time_per_inference']*1000:.2f} ms")
            print(f"  Frames per second (FPS): {results['fps']:.2f}")

        # Prepare data for plotting
        model_names = list(comparison_results)
        avg_times = [entry['avg_time_per_inference'] * 1000 for entry in comparison_results.values()]
        fps_values = [entry['fps'] for entry in comparison_results.values()]

        # One panel per metric: (subplot index, values, color, y-label, title, label suffix)
        panels = [
            (1, avg_times, 'skyblue', "Time (ms)", "Average Inference Time", " ms"),
            (2, fps_values, 'green', "FPS", "Frames Per Second", ""),
        ]

        fig = plt.figure(figsize=(12, 6))

        for index, values, color, ylabel, title, suffix in panels:
            plt.subplot(1, 2, index)
            bars = plt.bar(model_names, values, color=color)
            plt.ylabel(ylabel)
            plt.title(title)
            plt.xticks(rotation=45, ha='right')
            plt.grid(axis='y', linestyle='--', alpha=0.7)

            # Add values on the bars
            for bar in bars:
                height = bar.get_height()
                plt.text(bar.get_x() + bar.get_width()/2., height + 0.1,
                       f"{height:.2f}{suffix}", ha='center', va='bottom')

        plt.tight_layout()

        # Save the visualization
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, "model_comparison.png")
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

        print(f"Model comparison saved to {save_path}")
        plt.show()

        return comparison_results

    except Exception as e:
        # Notebook-friendly best effort: report the failure and keep the cell alive
        print(f"Error comparing models: {e}")
        import traceback
        traceback.print_exc()
        return None

    finally:
        # Release the figure so repeated comparisons do not accumulate open figures
        if fig is not None:
            plt.close(fig)



In [None]:
# --- Cell 7: Complete Inference Pipeline ---
"""
## Complete Inference Pipeline

This section demonstrates a complete inference pipeline from model loading to prediction.
Follow this example to perform inference on your own images.
"""

def run_complete_inference_pipeline(model_name, checkpoint_type="best",
                                  image_path=None, image_folder=None,
                                  model_checkpoints_dir="model_checkpoints",
                                  save_dir="inference_results"):
    """
    Run a complete inference pipeline from model loading to prediction.

    Steps: (1) load the model and class names, (2) single-image inference plus
    a confidence chart when ``image_path`` is given, (3) batch inference when
    ``image_folder`` is given, (4) a speed benchmark on ``image_path``.
    A missing image or folder is reported and that step is skipped; the
    remaining steps still run.

    Args:
        model_name: Name of the model to use
        checkpoint_type: Type of checkpoint to use ('best', 'final', or specific path)
        image_path: Path to a single image for inference (optional)
        image_folder: Path to a folder of images for batch inference (optional)
        model_checkpoints_dir: Directory containing model checkpoints
        save_dir: Directory to save the results

    Returns:
        dict: Inference results with any of the keys 'single_image',
        'confidence_scores', 'batch_results' and 'benchmark', depending on
        which steps produced output; None when the model or class names could
        not be loaded.
    """
    print(f"Running complete inference pipeline with model: {model_name}")

    # Step 1: Load the model and class names
    print("\nStep 1: Loading model and class names...")
    model, class_names = load_inference_model(model_name, checkpoint_type, model_checkpoints_dir)

    if model is None or class_names is None:
        print("Failed to load model or class names. Aborting inference.")
        return None

    results = {}

    # Checked once and reused by steps 2 and 4, so a missing image is reported
    # a single time instead of once per step.
    image_available = image_path is not None and os.path.exists(image_path)

    # Step 2: Single image inference (if provided)
    if image_path is not None:
        print(f"\nStep 2: Running inference on single image: {image_path}")

        if not image_available:
            print(f"ERROR: Image file '{image_path}' not found")
        else:
            # Perform inference
            single_result = infer_single_image(
                model=model,
                class_names=class_names,
                image_path=image_path,
                save_dir=os.path.join(save_dir, "single")
            )

            if single_result is not None:
                results['single_image'] = single_result

                # Visualize model confidence
                confidence_scores = visualize_model_confidence(
                    model=model,
                    class_names=class_names,
                    image_path=image_path,
                    save_dir=os.path.join(save_dir, "single")
                )

                if confidence_scores is not None:
                    results['confidence_scores'] = confidence_scores

    # Step 3: Batch inference (if folder provided)
    if image_folder is not None:
        print(f"\nStep 3: Running batch inference on images in: {image_folder}")

        if not os.path.exists(image_folder):
            print(f"ERROR: Image folder '{image_folder}' not found")
        else:
            # Perform batch inference
            batch_results = infer_batch_images(
                model=model,
                class_names=class_names,
                image_folder=image_folder,
                save_dir=os.path.join(save_dir, "batch"),
                batch_size=16,
                visualize=True
            )

            if batch_results is not None:
                # Only a summary is kept here; the file path mirrors the
                # save_dir handed to infer_batch_images above.
                results['batch_results'] = {
                    'count': len(batch_results),
                    'file': os.path.join(save_dir, "batch", "batch_results.json")
                }

    # Step 4: Benchmark model performance (needs a usable sample image)
    print("\nStep 4: Benchmarking model performance...")

    if image_available:
        benchmark_results = benchmark_model_performance(
            model=model,
            image_path=image_path,
            class_names=class_names,
            device=device,
            num_iterations=100,
            save_dir=save_dir
        )

        if benchmark_results is not None:
            results['benchmark'] = benchmark_results
    else:
        print("Skipping benchmark: no valid sample image was provided.")

    # Only claim success when at least one step actually produced output
    if results:
        print("\nInference pipeline completed successfully.")
    else:
        print("\nInference pipeline completed, but no results were produced.")
    return results



In [None]:
# --- Cell 8: Run Inference (User Code) ---
"""
## Run Inference

This is where you run the actual inference pipeline with your trained models and images.
Uncomment and modify the code below to perform inference on your own images.
"""

# Option 1: Run complete inference pipeline
# NOTE: the example is wrapped in a triple-quoted string, so as written it is a
# no-op expression statement. "Uncommenting" it means removing the surrounding
# triple quotes and replacing the placeholder paths.
"""
results = run_complete_inference_pipeline(
    model_name='improved_cnn',  # Replace with your model name
    checkpoint_type='best',
    image_path='path/to/your/image.jpg',  # Replace with your image path
    image_folder=None,  # Optionally specify a folder for batch inference
    model_checkpoints_dir='model_checkpoints',
    save_dir='inference_results'
)
"""

# Option 2: Load model and perform specific inference tasks
# NOTE: disabled example (a bare triple-quoted string); remove the surrounding
# triple quotes to run it. It loads one checkpoint and then calls the
# single-image, batch and text-recognition helpers individually.
"""
# Load model
model_name = 'improved_cnn'  # Replace with your model name
checkpoint_path = 'model_checkpoints/improved_cnn/best_model.pth'  # Replace with your checkpoint path
class_names_path = 'model_checkpoints/improved_cnn/class_names.txt'  # Replace with your class names path

# Load class names
class_names = load_class_names(class_names_path)

# Determine number of classes
num_classes = len(class_names) if class_names else 62  # Default to 62 classes (10 digits + 26*2 letters)

# Load model
model = load_model_checkpoint(checkpoint_path, model_name, num_classes, device)

# Single image inference
image_path = 'path/to/your/image.jpg'  # Replace with your image path
prediction = infer_single_image(model, class_names, image_path)

# Batch inference
image_folder = 'path/to/your/images'  # Replace with your image folder
batch_results = infer_batch_images(model, class_names, image_folder)

# Handwritten text recognition
text_image_path = 'path/to/your/text_image.jpg'  # Replace with your text image path
recognized_text, character_info, visualization_images = recognize_handwritten_text(
    model, class_names, text_image_path
)
"""

# Option 3: Compare multiple models
# NOTE: disabled example (a bare triple-quoted string); remove the surrounding
# triple quotes to run it. Each model is added to models_configs only when its
# checkpoint loaded (load_model_checkpoint returned something other than None).
# NOTE(review): unlike Option 2, len(class_namesN) is called without a fallback,
# so this presumably fails if load_class_names returns None - confirm.
"""
# Load models
models_configs = []

# Model 1: Basic CNN
model1_name = 'basic_cnn'
model1_path = 'model_checkpoints/basic_cnn/best_model.pth'
class_names1_path = 'model_checkpoints/basic_cnn/class_names.txt'
class_names1 = load_class_names(class_names1_path)
num_classes1 = len(class_names1)
model1 = load_model_checkpoint(model1_path, model1_name, num_classes1, device)
if model1 is not None:
    models_configs.append({
        'name': model1_name,
        'model': model1,
        'class_names': class_names1
    })

# Model 2: Improved CNN
model2_name = 'improved_cnn'
model2_path = 'model_checkpoints/improved_cnn/best_model.pth'
class_names2_path = 'model_checkpoints/improved_cnn/class_names.txt'
class_names2 = load_class_names(class_names2_path)
num_classes2 = len(class_names2)
model2 = load_model_checkpoint(model2_path, model2_name, num_classes2, device)
if model2 is not None:
    models_configs.append({
        'name': model2_name,
        'model': model2,
        'class_names': class_names2
    })

# Model 3: VGG19
model3_name = 'vgg19'
model3_path = 'model_checkpoints/vgg19/best_model.pth'
class_names3_path = 'model_checkpoints/vgg19/class_names.txt'
class_names3 = load_class_names(class_names3_path)
num_classes3 = len(class_names3)
model3 = load_model_checkpoint(model3_path, model3_name, num_classes3, device)
if model3 is not None:
    models_configs.append({
        'name': model3_name,
        'model': model3,
        'class_names': class_names3
    })

# Compare models
sample_image_path = 'path/to/your/sample_image.jpg'  # Replace with your sample image path
comparison_results = compare_models_performance(models_configs, sample_image_path)
"""

# Closing status banner: two lines telling the user the cell ran and what to do next.
print(
    "This notebook is ready for inference with handwritten character recognition models.",
    "Uncomment one of the inference options above and run this cell to start inference.",
    sep="\n",
)