In [8]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from pathlib import Path
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import os
from ocr_model import OCRModel

In [12]:
class CaptchaDataset(Dataset):
    """Dataset of CAPTCHA images whose labels are encoded as integer index sequences.

    Labels are stripped of non-alphanumeric characters before encoding. Index 0 is
    reserved for the CTC blank, so real characters are numbered from 1 when the
    mappings are built here.

    Args:
        image_paths: Sequence of image paths, aligned one-to-one with ``labels``.
        labels: Sequence of label strings.
        img_width, img_height: Stored on the instance only; this class does not
            read them (resizing is left to ``transform``).
        transform: Optional callable applied to the grayscale PIL image.
        char_to_idx, idx_to_char: Optional character <-> index mappings. If neither
            is given, both are built from the characters seen in ``labels``. If only
            one is given, the other is derived by inverting it.
    """

    def __init__(self, image_paths, labels, img_width=200, img_height=50, transform=None,
                 char_to_idx=None, idx_to_char=None):
        self.image_paths = image_paths
        # Keep only alphanumeric characters. Note that str.isalnum also accepts
        # non-ASCII letters/digits, which a supplied mapping may not contain.
        self.labels = [''.join(char for char in label if char.isalnum()) for label in labels]
        self.img_width = img_width
        self.img_height = img_height
        self.transform = transform

        if char_to_idx is None and idx_to_char is None:
            # Build the vocabulary from the labels; index 0 is left for the CTC blank.
            characters = sorted(set(char for label in self.labels for char in label))
            char_to_idx = {char: idx + 1 for idx, char in enumerate(characters)}
            idx_to_char = {idx + 1: char for idx, char in enumerate(characters)}
            idx_to_char[0] = ''
        elif idx_to_char is None:
            # Honour a lone char_to_idx instead of silently rebuilding from labels.
            idx_to_char = {idx: char for char, idx in char_to_idx.items()}
            idx_to_char.setdefault(0, '')
        elif char_to_idx is None:
            # Honour a lone idx_to_char; the blank (index 0) is not a real character.
            char_to_idx = {char: idx for idx, char in idx_to_char.items() if idx != 0}

        self.char_to_idx = char_to_idx
        self.idx_to_char = idx_to_char
        # Vocabulary in index order; now defined no matter where the mappings came from.
        self.characters = [char for char, _ in sorted(char_to_idx.items(), key=lambda kv: kv[1])]

    def __len__(self):
        """Number of samples (one per image path)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            Dict with 'image' (grayscale image, transformed if ``transform`` is
            set), 'label' (1-D long tensor of character indices), 'label_length'
            (0-d long tensor) and 'text' (the cleaned label string).

        Raises:
            KeyError: if the label contains a character missing from ``char_to_idx``.
        """
        img_path = self.image_paths[idx]
        # convert() loads the pixel data, so the file can be closed right away.
        with Image.open(img_path) as img:
            image = img.convert('L')
        if self.transform:
            image = self.transform(image)
        label = self.labels[idx]
        try:
            label_indices = [self.char_to_idx[char] for char in label]
        except KeyError as exc:
            raise KeyError(
                f"label {label!r} of {img_path!r} contains character {exc.args[0]!r} "
                "that is not in char_to_idx"
            ) from exc
        return {
            'image': image,
            'label': torch.tensor(label_indices, dtype=torch.long),
            'label_length': torch.tensor(len(label_indices), dtype=torch.long),
            'text': label
        }

def custom_collate_fn(batch):
    """Collate dataset items whose encoded labels have different lengths.

    Items are ordered by label length, longest first. The sort is stable, so ties
    keep their incoming order. Each label is right-padded with 0 (the CTC blank
    index in this notebook's mappings) to the longest label in the batch.

    Args:
        batch: Non-empty list of dicts with keys 'image' (tensor), 'label' (1-D
            long tensor), 'label_length' (0-d tensor) and 'text' (str). The list
            itself is not modified.

    Returns:
        Dict with 'image' (stacked images), 'label' (batch x max_len long tensor),
        'label_length' (1-D tensor) and 'text' (list of str), all in sorted order.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    from torch.nn.utils.rnn import pad_sequence

    if not batch:
        raise ValueError("custom_collate_fn received an empty batch")

    # sorted() rather than list.sort() so the caller's list is left untouched.
    items = sorted(batch, key=lambda item: len(item['label']), reverse=True)

    return {
        'image': torch.stack([item['image'] for item in items]),
        'label': pad_sequence([item['label'] for item in items], batch_first=True, padding_value=0),
        'label_length': torch.stack([item['label_length'] for item in items]),
        'text': [item['text'] for item in items]
    }



def decode_predictions(log_probs, idx_to_char):
    """Greedy (best-path) CTC decoding of per-timestep scores into strings.

    For every sequence the argmax class is taken at each timestep, consecutive
    repeats are collapsed, and index 0 (the CTC blank) is dropped.

    Args:
        log_probs: Tensor reduced with argmax over dim 2; dim 0 is iterated as the
            batch and dim 1 as time, i.e. assumed (batch, time, num_classes).
            TODO confirm this matches OCRModel's output layout.
        idx_to_char: Mapping from class index (int) to character.

    Returns:
        List of decoded strings, one per element of dim 0.

    Raises:
        KeyError: if a predicted non-blank index is missing from ``idx_to_char``.
    """
    from itertools import groupby

    # One device->host transfer for the whole batch. Comparing CUDA tensor
    # elements individually forced a synchronisation at every timestep.
    pred_indices = torch.argmax(log_probs, dim=2).cpu().tolist()

    batch_texts = []
    for pred in pred_indices:
        # groupby yields one key per run of identical indices, which collapses repeats.
        chars = [idx_to_char[idx] for idx, _ in groupby(pred) if idx != 0]
        batch_texts.append(''.join(chars))

    return batch_texts

def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0 when ``denominator`` is not positive."""
    return numerator / denominator if denominator > 0 else 0


def evaluate_model(model, val_loader, idx_to_char, device):
    """Run ``model`` over ``val_loader`` and compute sequence- and character-level metrics.

    Character metrics are positional: the i-th predicted character is a hit only
    if it equals the i-th target character.

    Args:
        model: Callable on a batch of images. Its output is decoded with
            ``decode_predictions``. ``model.eval()`` is called and not undone.
        val_loader: Iterable of batches with keys 'image' (tensor) and 'text'
            (list of target strings).
        idx_to_char: Index -> character mapping used for decoding.
        device: Device the images are moved to.

    Returns:
        Dict with 'sequence_accuracy', 'char_accuracy', 'precision', 'recall' and
        'f1'. Every metric is 0 when the loader yields no samples.

    Raises:
        ValueError: if a batch decodes to a different number of predictions than
            it has targets.
    """
    model.eval()

    all_predictions = []
    all_targets = []

    with torch.no_grad():
        for batch in val_loader:
            images = batch['image'].to(device)
            targets = batch['text']  # Original (unencoded) label strings

            log_probs = model(images)
            predictions = decode_predictions(log_probs, idx_to_char)
            # A mismatch (e.g. a time-first model output) would otherwise be
            # silently truncated by zip() and yield meaningless metrics.
            if len(predictions) != len(targets):
                raise ValueError(
                    f"decoded {len(predictions)} predictions for a batch of "
                    f"{len(targets)} targets; check the model output layout"
                )

            all_predictions.extend(predictions)
            all_targets.extend(targets)

    # Sequence level: exact string match. Guarded like the other ratios, because
    # an empty loader previously raised ZeroDivisionError here.
    correct_sequences = sum(p == t for p, t in zip(all_predictions, all_targets))
    sequence_accuracy = _safe_ratio(correct_sequences, len(all_predictions))

    # Character level. The overlap is compared directly instead of padding with
    # spaces, which could spuriously match a space in a label.
    char_tp = 0  # predicted chars equal to the target char at the same position
    char_fp = 0  # predicted chars that are not such a hit
    char_fn = 0  # target chars that were not hit
    total_chars = 0
    for pred, target in zip(all_predictions, all_targets):
        hits = sum(p == t for p, t in zip(pred, target))
        char_tp += hits
        char_fp += len(pred) - hits
        char_fn += len(target) - hits
        total_chars += max(len(pred), len(target))

    char_accuracy = _safe_ratio(char_tp, total_chars)
    precision = _safe_ratio(char_tp, char_tp + char_fp)
    recall = _safe_ratio(char_tp, char_tp + char_fn)
    f1 = _safe_ratio(2 * (precision * recall), precision + recall)

    return {
        'sequence_accuracy': sequence_accuracy,
        'char_accuracy': char_accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }

# Vocabulary: digits, uppercase and lowercase ASCII letters, in code-point order.
characters = sorted(
    chr(code)
    for first, last in (('0', '9'), ('A', 'Z'), ('a', 'z'))
    for code in range(ord(first), ord(last) + 1)
)

# Index 0 is reserved for the CTC blank, so real characters start at 1.
char_to_idx = {ch: pos for pos, ch in enumerate(characters, start=1)}
idx_to_char = {pos: ch for ch, pos in char_to_idx.items()}
idx_to_char[0] = ''  # Add blank token



In [29]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Evaluation settings. IMG_HEIGHT/IMG_WIDTH match CaptchaDataset's defaults and
# presumably the training resolution of OCRModel; verify against training code.
CHECKPOINT_PATH = 'best_model.pth'
IMG_HEIGHT, IMG_WIDTH = 50, 200
BATCH_SIZE = 32
IMAGE_PATTERNS = ('*.png', '*.jpg')

# Load the model.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# For a plain state_dict, consider weights_only=True (available on torch >= 1.13).
model = OCRModel()
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
model = model.to(device)

# Set up data transforms
transform = transforms.Compose([
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.ToTensor(),
])

# Load validation data. The label is the file name up to its first '.'.
# The paths are sorted so the evaluation order does not depend on the filesystem.
val_dir = "./datasets/validation"
val_files = sorted(
    img_path for pattern in IMAGE_PATTERNS for img_path in Path(val_dir).glob(pattern)
)
image_paths = [str(img_path) for img_path in val_files]
labels = [img_path.stem.split('.')[0] for img_path in val_files]

print(f"Found {len(image_paths)} validation images")
if not image_paths:
    # Fail fast with a clear message instead of evaluating an empty set.
    raise FileNotFoundError(
        f"No {'/'.join(IMAGE_PATTERNS)} images found in {Path(val_dir).resolve()}"
    )

# Create validation dataset and loader
val_dataset = CaptchaDataset(
    image_paths,
    labels,
    img_width=IMG_WIDTH,
    img_height=IMG_HEIGHT,
    transform=transform,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=custom_collate_fn
)

# Evaluate the model
metrics = evaluate_model(model, val_loader, idx_to_char, device)

# Print results
print("\nModel Evaluation Results:")
print(f"Sequence Accuracy: {metrics['sequence_accuracy']:.4f}")
print(f"Character Accuracy: {metrics['char_accuracy']:.4f}")
print(f"Character Precision: {metrics['precision']:.4f}")
print(f"Character Recall: {metrics['recall']:.4f}")
print(f"Character F1 Score: {metrics['f1']:.4f}")

Using device: cuda
Found 769 validation images

Model Evaluation Results:
Sequence Accuracy: 0.9350
Character Accuracy: 0.9794
Character Precision: 0.9816
Character Recall: 0.9801
Character F1 Score: 0.9808


In [34]:
# 