In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNNRNN(nn.Module):
    """CNN + bidirectional LSTM recogniser emitting per-column log-probabilities.

    Pipeline, as implemented below:
      1. Five Conv -> BatchNorm -> ReLU -> MaxPool stages. The first two pool
         height and width by 2; the last three pool width only. Overall the
         height is divided by 4 and the width by 32 (floor division).
      2. A 1x1 convolution mixes the 256 feature channels. The channel count
         stays at 256; nothing is reduced here.
      3. The height axis is averaged away, so every remaining image column
         becomes one 256-dimensional time step regardless of image height.
      4. A 2-layer bidirectional LSTM runs over the columns, followed by a
         linear classifier and a log-softmax over the classes.

    Args:
        num_chars: Number of classes predicted per time step (presumably this
            count includes the CTC blank, since decoding treats one class
            index as blank).
        input_channels: Number of channels of the input images.
    """

    def __init__(self, num_chars, input_channels=1):
        super().__init__()

        # (in_channels, out_channels, conv kernel size, pooling window) per stage.
        # Stages are created in this exact order and flattened into a single
        # nn.Sequential so the parameter names (cnn.0, cnn.1, ... cnn.19) stay
        # identical to checkpoints saved from the hand-written layer list.
        stage_specs = [
            (input_channels, 32, 5, (2, 2)),
            (32, 64, 5, (2, 2)),
            (64, 128, 3, (1, 2)),
            (128, 128, 3, (1, 2)),
            (128, 256, 3, (1, 2)),
        ]
        layers = []
        for in_ch, out_ch, kernel, pool in stage_specs:
            layers += [
                # padding = kernel // 2 keeps the spatial size through the conv.
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=1, padding=kernel // 2),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=pool, stride=pool),
            ]
        self.cnn = nn.Sequential(*layers)

        self.lstm_hidden_size = 256

        # 1x1 convolution applied to the CNN output (256 -> 256 channels).
        self.feature_projection = nn.Conv2d(256, 256, kernel_size=1, stride=1)

        # forward() averages over the height axis, so each time step carries
        # exactly the 256 projected channels.
        self.rnn_input_size = 256

        # RNN part
        self.lstm = nn.LSTM(
            input_size=self.rnn_input_size,
            hidden_size=self.lstm_hidden_size,
            num_layers=2,
            bidirectional=True,
            batch_first=True
        )

        # Final classifier; the factor 2 accounts for the two LSTM directions.
        self.classifier = nn.Linear(self.lstm_hidden_size * 2, num_chars)

    def forward(self, x):
        """Compute per-time-step log-probabilities.

        Args:
            x: Image batch of shape [batch, input_channels, height, width].

        Returns:
            Tensor of shape [batch, width // 32, num_chars] holding
            log-softmax scores over the classes (batch-first layout).
        """
        features = self.cnn(x)                        # [batch, 256, height/4, width/32]
        features = self.feature_projection(features)  # same shape, channels mixed

        # Collapse the height axis by averaging, then put time (width) second.
        features = features.mean(dim=2)               # [batch, 256, width/32]
        features = features.permute(0, 2, 1)          # [batch, width/32, 256]

        rnn_output, _ = self.lstm(features)           # [batch, width/32, 2*hidden_size]
        logits = self.classifier(rnn_output)          # [batch, width/32, num_chars]

        # Log-softmax over the class axis, as expected by a CTC loss.
        return F.log_softmax(logits, dim=2)

    def predict(self, x):
        """Return the arg-max class index per time step, shape [batch, width // 32].

        Runs without gradient tracking: arg-max is not differentiable, so
        recording the graph would only cost memory. The module's train/eval
        mode is left untouched; callers choose it.
        """
        with torch.no_grad():
            # Call the module (not .forward) so registered hooks still run.
            log_probs = self(x)
            return torch.argmax(log_probs, dim=2)


# Usage example with CTC loss
def train_batch(model, criterion, optimizer, images, targets, target_lengths):
    """Run one optimisation step on a single batch and return the loss value.

    Args:
        model: Callable returning batch-first log-probabilities
            [batch, time, classes] for ``images``.
        criterion: Loss called as ``criterion(log_probs, targets,
            input_lengths, target_lengths)`` with time-major log-probabilities
            (the CTC loss call convention).
        optimizer: Optimizer whose gradients are reset before the step and
            which is stepped after the backward pass.
        images: Input batch passed to ``model``.
        targets: Target label indices passed through to ``criterion``.
        target_lengths: Length of each target, passed through to ``criterion``.

    Returns:
        The batch loss as a Python float.
    """
    optimizer.zero_grad()

    # Forward pass: [batch, time, classes]
    scores = model(images)
    num_items, num_steps = scores.size(0), scores.size(1)

    # Every sample uses the full output width as its input length.
    input_lengths = torch.full((num_items,), num_steps, dtype=torch.long)

    # The criterion expects the time axis first: [time, batch, classes]
    time_major_scores = scores.transpose(0, 1)
    loss = criterion(time_major_scores, targets, input_lengths, target_lengths)

    # Backward pass and parameter update
    loss.backward()
    optimizer.step()

    return loss.item()


# Decoding function for inference
def decode_predictions(predictions, idx_to_char, blank_idx=0):
    """
    Decode the model predictions using CTC best path decoding

    Each sequence is collapsed by dropping consecutive repeats; blank tokens
    are then removed and the remaining indices are mapped to characters.

    Args:
        predictions: Per-time-step class indices [batch_size, sequence_length].
            May be a tensor, a NumPy array, or a nested list/tuple of ints.
        idx_to_char: Mapping from index to character
        blank_idx: Index of the blank token

    Returns:
        List of decoded texts, one string per sequence

    Raises:
        KeyError: If a non-blank index is missing from a dict ``idx_to_char``.
    """
    results = []

    for pred in predictions:
        # Convert the whole sequence to Python ints in one call. Per-element
        # .item() calls would force one host sync per time step on GPU tensors.
        indices = pred.tolist() if hasattr(pred, 'tolist') else list(pred)

        chars = []
        previous = None  # sentinel: never equal to a real class index
        for idx in indices:
            # Keep an index only where a new run starts and it is not blank.
            # `previous` tracks blanks too, so "a, blank, a" decodes to "aa".
            if idx != previous and idx != blank_idx:
                chars.append(idx_to_char[idx])
            previous = idx

        results.append(''.join(chars))

    return results

In [4]:
import torch
import matplotlib.pyplot as plt
import Levenshtein

def visualize_predictions(model, dataloader, char_to_idx, device, num_samples=5):
    """
    Plot the first usable batch of a dataloader with predicted and true text.

    The model is switched to evaluation mode. At most ``num_samples`` images
    of the first batch whose images are not ``None`` are shown, one per row,
    each titled with its ground truth, its decoded prediction and the
    character error rate (green title on an exact match, red otherwise).

    Args:
        model: Trained CNNRNN model
        dataloader: Iterable yielding (images, texts, text_lengths) batches
        char_to_idx: Character to index mapping
        device: Device to run inference on
        num_samples: Maximum number of samples to visualize
    """
    model.eval()

    # Invert the vocabulary for decoding; index 0 is treated as the CTC blank.
    idx_to_char = {index: symbol for symbol, index in char_to_idx.items()}
    blank_idx = 0

    for images, texts, text_lengths in dataloader:
        # Skip batches without images until a usable one turns up.
        if images is None:
            continue

        shown = min(num_samples, images.size(0))
        images = images[:shown].to(device)
        texts = texts[:shown]
        text_lengths = text_lengths[:shown]

        # Inference without gradient tracking
        with torch.no_grad():
            predictions = model.predict(images)
            decoded_preds = decode_predictions(predictions, idx_to_char, blank_idx)

        # Rebuild the reference strings, ignoring non-positive label indices.
        ground_truths = []
        for row, length in enumerate(text_lengths):
            symbols = []
            for col in range(length):
                if texts[row, col] > 0:
                    symbols.append(idx_to_char.get(texts[row, col].item(), ''))
            ground_truths.append(''.join(symbols))

        # Character error rate: edit distance normalised by reference length
        # (an empty reference divides by 1 instead of 0).
        cers = [
            Levenshtein.distance(pred, gt) / max(len(gt), 1)
            for pred, gt in zip(decoded_preds, ground_truths)
        ]

        plt.figure(figsize=(15, 4 * shown))
        for row in range(shown):
            plt.subplot(shown, 1, row + 1)

            # First channel of the image, drawn in grayscale
            plt.imshow(images[row, 0].cpu().numpy(), cmap='gray')

            is_match = decoded_preds[row] == ground_truths[row]
            title = (
                f"Ground truth: '{ground_truths[row]}'\n"
                f"Prediction: '{decoded_preds[row]}'\n"
                f"CER: {cers[row]:.2f}"
            )
            plt.title(title, color='green' if is_match else 'red')
            plt.axis('off')

        plt.tight_layout()
        plt.show()

        # Only the first usable batch is visualized.
        return

# Usage
def check_trained_model(model, test_loader, char_to_idx, device, num_samples=5):
    """Move ``model`` to ``device`` and visualize its predictions on ``test_loader``.

    Args:
        model: Model to inspect; ``model.to(device)`` is called on it.
        test_loader: Dataloader handed to ``visualize_predictions``.
        char_to_idx: Character to index mapping.
        device: Device used for inference.
        num_samples: Maximum number of samples to display.
    """
    model_on_device = model.to(device)
    visualize_predictions(
        model_on_device,
        test_loader,
        char_to_idx,
        device,
        num_samples,
    )
    
# Example call:
# check_trained_model(trained_model, test_loader, char_to_idx, device, num_samples=8)

In [9]:
# Class count and checkpoint location for the base model. NUM_CHARS has to
# match the checkpoint's classifier size because load_state_dict is strict
# by default.
NUM_CHARS = 77
BASE_WEIGHTS_PATH = 'model_weights/base_model_weight.pt'

base_model = CNNRNN(num_chars=NUM_CHARS)

# The model is built on the CPU, so map the stored tensors to the CPU as well;
# without map_location a checkpoint saved from a GPU fails to load on a
# CPU-only machine.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
base_state_dict = torch.load(BASE_WEIGHTS_PATH, map_location='cpu')
base_model.load_state_dict(base_state_dict)

# Inference mode for BatchNorm; as the last expression this also displays the
# model summary in the notebook.
base_model.eval()

CNNRNN(
  (cnn): Sequential(
    (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0, dilation=1, ceil_mode=False)
    (12): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, tra