# CRNN implemented using PyTorch

In [None]:
# Imports


import time
import torch
import random

import torch.nn as nn

import torch.nn.functional as F



from torch.utils.data import DataLoader, Dataset

import os

from PIL import Image, ImageFilter, ImageEnhance

from torchvision import transforms



import torch

from torch.utils.data import Dataset

from torchvision import transforms

from torch.utils.data import DataLoader, random_split

import matplotlib.pyplot as plt



import numpy as np

import torch.optim as optim

from torch.nn import CTCLoss


from sklearn.model_selection import train_test_split

from torch.utils.data import Subset

In [None]:
# Global variables

# Default vocabulary: upper-case letters, lower-case letters, digits, then a trailing space
# (63 entries). A later cell reassigns char_list to a list in a different order.
char_list = (
    'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    'abcdefghijklmnopqrstuvwxyz'
    '0123456789'
    ' '
)

# Not read anywhere else in this notebook.
max_label_len = 0


In [None]:
import os
import matplotlib.pyplot as plt
from PIL import Image, ImageEnhance
from collections import Counter
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

# Encode the label to numerical format with padding or truncation to fixed length
def encode_to_labels(txt, char_list, target_len=11):
    """Turn ``txt`` into vocabulary indices of exactly ``target_len`` entries.

    Characters absent from ``char_list`` are reported on stdout and skipped.
    A result longer than ``target_len`` is truncated; a shorter one is
    right-padded with the index of the last entry of ``char_list``.

    Args:
        txt: iterable of characters to encode.
        char_list: vocabulary (string or list) searched with ``.index``.
        target_len: length of the returned list.

    Returns:
        A list of ints of length ``target_len`` (for non-negative ``target_len``).
    """
    indices = []
    for ch in txt:
        if ch in char_list:
            indices.append(char_list.index(ch))
        else:
            print(f"Character not found: {ch}")

    if len(indices) > target_len:
        return indices[:target_len]

    # The last vocabulary entry doubles as the padding token.
    pad_index = len(char_list) - 1
    return indices + [pad_index] * (target_len - len(indices))

# Custom Dataset class with augmentation and tracking
class RecogDataset(Dataset):
    """In-memory text-recognition dataset built from labelled image files.

    Every ``.png``/``.jpg``/``.jpeg`` file under ``data_path`` (searched recursively)
    is loaded. The label is the part of the file name between the first ``_`` and the
    next ``.``; for example, ``img_AB12.png`` yields ``AB12``. Files without a ``_``
    are skipped.

    Images are processed as follows:
    - converted to grayscale;
    - passed once through ``self.transform``, which applies random rotation, affine,
      colour jitter and blur, resizes to ``target_size``, then applies ``ToTensor`` and
      ``Normalize((0.5,), (0.5,))`` so pixel values end up in [-1, 1];
    - stored as tensors.

    NOTE(review): the random transform is applied to every loaded image, including
    those that later end up in validation/test splits — confirm this is intended.

    Args:
        data_path: root directory to scan for images.
        char_list: vocabulary passed to ``encode_to_labels``; its last entry is the
            padding index for labels.
        target_size: (height, width) every image is resized to.
        pred_len: fixed length of each encoded label; also stored as every sample's
            ``input_length``.
        augment: if True, call ``balance_dataset`` after loading.
        skip_text: if True, keep only labels that are exactly 11 characters long and
            contain at least one digit.
    """

    def __init__(self, data_path, char_list, target_size=(32, 512), pred_len=11, augment=True, skip_text=True):
        self.images = []
        self.labels = []
        self.label_length = []
        self.input_length = []
        self.orig_txt = []
        self.char_list = char_list
        self.pred_len = pred_len  # Fixed length of labels
        self.augment = augment
        self.character_counts = Counter()  # Character frequencies over kept labels

        # Random augmentation, then resize and normalisation to [-1, 1].
        self.transform = transforms.Compose([
            transforms.RandomRotation(degrees=2),
            transforms.RandomAffine(
                degrees=0,
                translate=(0.02, 0.02),
                scale=(0.9, 1.1),
                shear=(-2, 2),
            ),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0)),
            transforms.Resize(target_size),  # Ensure all images have the same size
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])

        # Load all image paths and labels
        for root, _dirs, files in os.walk(data_path):
            for file in files:
                if not file.endswith(('.png', '.jpg', '.jpeg')):
                    continue
                if '_' not in file:
                    # No "<prefix>_<label>" pattern, so there is no label to extract.
                    print(f"Skipping file without a label: {file}")
                    continue

                image_path = os.path.join(root, file)
                label = file.split('_')[1].split('.')[0]  # Extract label from filename
                # NOTE(review): the length filter is hard-coded to 11 rather than pred_len.
                if skip_text and (len(label) != 11 or not any(char.isdigit() for char in label)):
                    continue
                # Track character counts for balancing
                self.character_counts.update(label)

                # Load as grayscale; the context manager closes the underlying file.
                with Image.open(image_path) as img:
                    image = img.convert('L')
                image = self.transform(image)

                encoded_label = encode_to_labels(label, self.char_list, target_len=self.pred_len)
                # Count only characters encode_to_labels can map. Unknown characters
                # are dropped during encoding, so counting them would make label_len
                # cover padding entries.
                known_len = sum(1 for char in label if char in self.char_list)
                label_len = min(known_len, self.pred_len)

                self.images.append(image)
                self.labels.append(encoded_label)
                self.label_length.append(label_len)
                self.input_length.append(self.pred_len)
                self.orig_txt.append(label)

        # Data augmentation to balance character distribution
        if self.augment:
            self.balance_dataset()

    def balance_dataset(self):
        """Append augmented copies of loaded samples to raise under-represented character counts.

        For each originally loaded sample and each distinct character in its label, a
        target is drawn as a random fraction (between 0.5 and 1.0) of the most frequent
        character's count. While that character's running count is below the target, an
        augmented copy of the sample is queued and only that character's count is
        incremented. Queued copies are appended after the loop. Does nothing if no
        labels were loaded.
        """
        if not self.character_counts:
            # Nothing was loaded; max() over an empty sequence would raise ValueError.
            return
        max_count = max(self.character_counts.values())  # Target count for balancing

        augmented_images = []
        augmented_labels = []
        augmented_label_lengths = []
        augmented_input_lengths = []
        augmented_orig_txts = []

        for idx in range(len(self.images)):
            label = self.orig_txt[idx]
            char_count = Counter(label)

            for char in char_count:
                random_max_count = max_count * (random.randint(50 + random.randint(0, 10), 100 + random.randint(-10, 0)) / 100.0)
                while self.character_counts[char] < random_max_count:
                    augmented_image = self.augment_image(self.images[idx])
                    augmented_images.append(augmented_image)
                    augmented_labels.append(self.labels[idx])
                    augmented_label_lengths.append(self.label_length[idx])
                    augmented_input_lengths.append(self.input_length[idx])
                    augmented_orig_txts.append(label)
                    self.character_counts[char] += 1

        # Append augmented data to the dataset
        self.images.extend(augmented_images)
        self.labels.extend(augmented_labels)
        self.label_length.extend(augmented_label_lengths)
        self.input_length.extend(augmented_input_lengths)
        self.orig_txt.extend(augmented_orig_txts)

    def augment_image(self, image):
        """Return a brightened, re-transformed copy of a stored (normalised) image tensor."""
        # Stored tensors are normalised to [-1, 1]. Undo Normalize((0.5,), (0.5,)) so
        # ToPILImage receives values in [0, 1]; it scales floats by 255 before casting
        # to uint8, and negative values would not survive that cast.
        image = (image * 0.5 + 0.5).clamp(0.0, 1.0)
        image_pil = transforms.ToPILImage()(image)
        enhancer = ImageEnhance.Brightness(image_pil)
        image_pil = enhancer.enhance(1 + random.random() * 0.25)  # Increase brightness
        return self.transform(image_pil)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return one sample as a dict of image tensor, encoded label, lengths and raw text."""
        image = self.images[idx]
        label = self.labels[idx]
        label_len = self.label_length[idx]
        input_len = self.input_length[idx]
        orig_txt = self.orig_txt[idx]

        return {
            'image': image,
            'label': torch.tensor(label, dtype=torch.long),
            'label_len': torch.tensor(label_len, dtype=torch.long),
            'input_len': torch.tensor(input_len, dtype=torch.long),
            'orig_txt': orig_txt
        }

    def plot_character_distribution(self):
        """Plot ``self.character_counts`` as a line chart, sorted by character."""
        char_counts = dict(self.character_counts)

        # Sort characters and counts based on the character keys
        sorted_characters = sorted(char_counts.keys())
        sorted_counts = [char_counts[char] for char in sorted_characters]

        # Plot the distribution
        plt.figure(figsize=(10, 5))
        plt.plot(sorted_characters, sorted_counts, marker='o', linestyle='-', color='b')
        plt.xlabel('Characters')
        plt.ylabel('Counts')
        plt.title('Character Distribution')
        plt.grid()
        plt.show()


In [None]:
# CRNN Model Architecture

class CRNN(nn.Module):
    """Convolutional recurrent network: CNN backbone, two bidirectional LSTMs, linear head.

    Args:
        img_channel: number of input channels.
        img_height: input height; must be divisible by 16.
        img_width: input width; must be divisible by 4.
        num_class: size of the per-step output.
        map_to_seq_hidden: size of the projection applied to each CNN column.
        rnn_hidden: hidden size of each LSTM direction.
        leaky_relu: use LeakyReLU(0.2) instead of ReLU in the CNN.
    """

    def __init__(self, img_channel, img_height, img_width, num_class,
                 map_to_seq_hidden=64, rnn_hidden=256, leaky_relu=False):
        super().__init__()

        backbone, (feat_channels, feat_height, _feat_width) = self._cnn_backbone(
            img_channel, img_height, img_width, leaky_relu)
        self.cnn = backbone

        # Each CNN output column (channels x height) becomes one sequence step.
        self.map_to_seq = nn.Linear(feat_channels * feat_height, map_to_seq_hidden)
        self.rnn1 = nn.LSTM(map_to_seq_hidden, rnn_hidden, bidirectional=True)
        self.rnn2 = nn.LSTM(2 * rnn_hidden, rnn_hidden, bidirectional=True)
        self.dense = nn.Linear(2 * rnn_hidden, num_class)

    def _cnn_backbone(self, img_channel, img_height, img_width, leaky_relu):
        """Build the CNN and return it together with its (channels, height, width) output size."""
        assert img_height % 16 == 0
        assert img_width % 4 == 0

        widths = [img_channel, 64, 128, 256, 256, 512, 512, 512]
        kernels = [3, 3, 3, 3, 3, 3, 2]
        strides = [1, 1, 1, 1, 1, 1, 1]
        pads = [1, 1, 1, 1, 1, 1, 0]

        # (conv index, add batch norm, pooling layer appended after the activation)
        schedule = [
            (0, False, ('pooling0', {'kernel_size': 2, 'stride': 2})),  # H/2,  W/2
            (1, False, ('pooling1', {'kernel_size': 2, 'stride': 2})),  # H/4,  W/4
            (2, False, None),
            (3, False, ('pooling2', {'kernel_size': (2, 1)})),          # H/8,  W/4
            (4, True, None),
            (5, True, ('pooling3', {'kernel_size': (2, 1)})),           # H/16, W/4
            (6, False, None),  # 2x2 conv without padding: H/16 - 1, W/4 - 1
        ]

        cnn = nn.Sequential()
        for idx, with_bn, pool in schedule:
            out_ch = widths[idx + 1]
            cnn.add_module(f'conv{idx}', nn.Conv2d(widths[idx], out_ch, kernels[idx], strides[idx], pads[idx]))
            if with_bn:
                cnn.add_module(f'batchnorm{idx}', nn.BatchNorm2d(out_ch))
            cnn.add_module(f'relu{idx}', nn.LeakyReLU(0.2, inplace=True) if leaky_relu else nn.ReLU(inplace=True))
            if pool is not None:
                pool_name, pool_kwargs = pool
                cnn.add_module(pool_name, nn.MaxPool2d(**pool_kwargs))

        out_size = (widths[-1], img_height // 16 - 1, img_width // 4 - 1)
        return cnn, out_size

    def forward(self, images):
        """Map (batch, channel, height, width) images to (seq_len, batch, num_class) scores."""
        features = self.cnn(images)
        n, c, h, w = features.size()

        # Merge channel and height into one feature axis; width becomes the time axis.
        steps = features.view(n, c * h, w).permute(2, 0, 1)  # (width, batch, feature)
        hidden, _ = self.rnn1(self.map_to_seq(steps))
        hidden, _ = self.rnn2(hidden)
        return self.dense(hidden)

In [None]:
class ModifiedCRNN1(nn.Module):
    """Single-LSTM variant of CRNN: the same CNN backbone, one bidirectional LSTM and a linear head.

    Expects input of shape (batch, img_channel, img_height, img_width), with img_height
    divisible by 16 and img_width divisible by 4.
    """

    def __init__(self, img_channel, img_height, img_width, num_class,
                 map_to_seq_hidden=64, rnn_hidden=256, leaky_relu=False):
        super().__init__()
        self.cnn, out_shape = self._cnn_backbone(img_channel, img_height, img_width, leaky_relu)
        out_channels, out_height, _out_width = out_shape

        self.map_to_seq = nn.Linear(out_channels * out_height, map_to_seq_hidden)
        self.rnn1 = nn.LSTM(map_to_seq_hidden, rnn_hidden, bidirectional=True)
        self.dense = nn.Linear(2 * rnn_hidden, num_class)

    def _cnn_backbone(self, img_channel, img_height, img_width, leaky_relu):
        """Build the CNN and return it with its (channels, height, width) output size."""
        assert img_height % 16 == 0
        assert img_width % 4 == 0

        channels = [img_channel, 64, 128, 256, 256, 512, 512, 512]
        kernel_sizes = [3, 3, 3, 3, 3, 3, 2]
        strides = [1, 1, 1, 1, 1, 1, 1]
        paddings = [1, 1, 1, 1, 1, 1, 0]

        def conv_layers(idx, with_bn):
            # conv -> [batch norm] -> activation, as (name, module) pairs.
            out_ch = channels[idx + 1]
            layers = [(f'conv{idx}', nn.Conv2d(channels[idx], out_ch, kernel_sizes[idx], strides[idx], paddings[idx]))]
            if with_bn:
                layers.append((f'batchnorm{idx}', nn.BatchNorm2d(out_ch)))
            act = nn.LeakyReLU(0.2, inplace=True) if leaky_relu else nn.ReLU(inplace=True)
            layers.append((f'relu{idx}', act))
            return layers

        # The list is built left to right, so modules are created in layer order.
        named_layers = [
            *conv_layers(0, False),
            ('pooling0', nn.MaxPool2d(kernel_size=2, stride=2)),
            *conv_layers(1, False),
            ('pooling1', nn.MaxPool2d(kernel_size=2, stride=2)),
            *conv_layers(2, False),
            *conv_layers(3, False),
            ('pooling2', nn.MaxPool2d(kernel_size=(2, 1))),
            *conv_layers(4, True),
            *conv_layers(5, True),
            ('pooling3', nn.MaxPool2d(kernel_size=(2, 1))),
            *conv_layers(6, False),
        ]

        cnn = nn.Sequential()
        for name, layer in named_layers:
            cnn.add_module(name, layer)

        return cnn, (channels[-1], img_height // 16 - 1, img_width // 4 - 1)

    def forward(self, images):
        """Map (batch, channel, height, width) images to (seq_len, batch, num_class) scores."""
        conv = self.cnn(images)
        batch, channel, height, width = conv.size()
        columns = conv.view(batch, channel * height, width).permute(2, 0, 1)  # (width, batch, feature)
        recurrent, _ = self.rnn1(self.map_to_seq(columns))
        return self.dense(recurrent)





In [None]:
class ModifiedCRNN2(nn.Module):
    """CRNN variant with a five-conv backbone widening to 1024 channels and two bidirectional LSTMs.

    Expects input of shape (batch, img_channel, img_height, img_width), with img_height
    divisible by 16 and img_width divisible by 4. ``img_width`` only feeds the
    divisibility check and the reported output width.
    """

    def __init__(self, img_channel, img_height, img_width, num_class,
                 map_to_seq_hidden=64, rnn_hidden=256, leaky_relu=False):
        super().__init__()

        self.cnn, dims = self._cnn_backbone(img_channel, img_height, img_width, leaky_relu)
        feat_channels, feat_height = dims[0], dims[1]

        # Projection input size follows the CNN output (channels x height per column).
        self.map_to_seq = nn.Linear(feat_channels * feat_height, map_to_seq_hidden)
        self.rnn1 = nn.LSTM(map_to_seq_hidden, rnn_hidden, bidirectional=True)
        self.rnn2 = nn.LSTM(2 * rnn_hidden, rnn_hidden, bidirectional=True)
        self.dense = nn.Linear(2 * rnn_hidden, num_class)

    def _cnn_backbone(self, img_channel, img_height, img_width, leaky_relu):
        """Build the CNN and return it with its (channels, height, width) output size."""
        assert img_height % 16 == 0
        assert img_width % 4 == 0

        channels = [img_channel, 64, 128, 256, 512, 1024]
        n_convs = len(channels) - 1
        kernel_sizes = [3] * n_convs
        strides = [1] * n_convs
        paddings = [1] * n_convs

        # Layer plan in module order: ('conv', index, batch_norm) or ('pool', name, kwargs).
        plan = (
            ('conv', 0, False),
            ('pool', 'pooling0', {'kernel_size': 2, 'stride': 2}),
            ('conv', 1, False),
            ('pool', 'pooling1', {'kernel_size': 2, 'stride': 2}),
            ('conv', 2, False),
            ('conv', 3, False),
            ('pool', 'pooling2', {'kernel_size': (2, 1)}),
            ('conv', 4, True),
            ('pool', 'pooling3', {'kernel_size': (2, 1)}),
        )

        cnn = nn.Sequential()
        for kind, key, option in plan:
            if kind == 'pool':
                cnn.add_module(key, nn.MaxPool2d(**option))
                continue
            out_ch = channels[key + 1]
            cnn.add_module(f'conv{key}', nn.Conv2d(channels[key], out_ch, kernel_sizes[key], strides[key], paddings[key]))
            if option:
                cnn.add_module(f'batchnorm{key}', nn.BatchNorm2d(out_ch))
            cnn.add_module(f'relu{key}', nn.LeakyReLU(0.2, inplace=True) if leaky_relu else nn.ReLU(inplace=True))

        # Two 2x2 pools halve both axes; two (2, 1) pools halve only the height.
        return cnn, (channels[-1], img_height // 16, img_width // 4)

    def forward(self, images):
        """Map (batch, channel, height, width) images to (seq_len, batch, num_class) scores."""
        feats = self.cnn(images)
        batch, channel, height, width = feats.size()

        steps = feats.view(batch, channel * height, width).permute(2, 0, 1)  # (width, batch, features)
        hidden, _ = self.rnn1(self.map_to_seq(steps))
        hidden, _ = self.rnn2(hidden)
        return self.dense(hidden)

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from collections import Counter

# Training loop: CTC training with checkpointing on validation accuracy; the final-epoch weights are evaluated on the test set.
def train(model, train_loader, val_loader, test_loader, num_epochs, initial_lr, char_list, model_save_path, device):
    """Train ``model`` with CTC loss, checkpointing whenever validation accuracy improves.

    If ``model_save_path`` exists, the model, optimizer and scheduler state and the epoch
    counter are restored from it, and training continues with the following epoch.
    After the loop, the final-epoch (not the best) weights are evaluated on
    ``test_loader``.

    NOTE(review): ``model_save_path`` is only rewritten when validation accuracy
    improves, so a resumed run restarts after the best epoch, not after the last one.

    Args:
        model: network whose output has shape (seq_len, batch, num_classes).
        train_loader, val_loader, test_loader: loaders yielding dicts with
            'image', 'label' and 'label_len'.
        num_epochs: total epoch count (including epochs restored from a checkpoint).
        initial_lr: Adam learning rate; reduced by ReduceLROnPlateau on validation loss.
        char_list: vocabulary; its last entry is the CTC blank.
        model_save_path: checkpoint path to resume from and save to. An additional
            copy tagged with the accuracy is saved next to it.
        device: device that batches are moved to.

    Returns:
        (train_losses, val_losses, val_accuracies): per-epoch lists for the epochs run
        in this call.
    """
    ctc_loss = nn.CTCLoss(blank=len(char_list) - 1, zero_infinity=True)
    optimizer = optim.Adam(model.parameters(), lr=initial_lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=5)

    # Resume training if a saved model exists
    start_epoch = 0
    best_val_accuracy = 0.0
    if os.path.exists(model_save_path):
        print(f"Resuming training from {model_save_path}...")
        # map_location lets a checkpoint written on a GPU be resumed on a CPU-only host.
        checkpoint = torch.load(model_save_path, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_val_accuracy = checkpoint['best_val_accuracy']
        print(f"Resumed from epoch {start_epoch}, best val accuracy: {best_val_accuracy:.4f}")

    # splitext strips only the real extension, even if a directory name contains ".pth".
    save_stem = os.path.splitext(model_save_path)[0]

    train_losses, val_losses, val_accuracies = [], [], []
    for epoch in range(start_epoch, num_epochs):
        model.train()
        epoch_loss = 0.0
        for batch in train_loader:
            images = batch['image'].to(device)
            targets = batch['label'].to(device)
            target_lengths = batch['label_len'].to(device)

            optimizer.zero_grad()
            outputs = model(images)
            # Every sample uses the full output sequence length.
            output_lengths = torch.full(size=(outputs.size(1),), fill_value=outputs.size(0),
                                        dtype=torch.long, device=device)

            loss = ctc_loss(outputs.log_softmax(2), targets, output_lengths, target_lengths)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        train_loss = epoch_loss / len(train_loader)
        train_losses.append(train_loss)

        # Validate the model
        val_loss, val_accuracy = validate_model(model, val_loader, ctc_loss, char_list, device)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        # Adjust learning rate based on validation loss
        scheduler.step(val_loss)

        # Save the best model based on validation accuracy
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            model_filename = f"{save_stem}_acc_{val_accuracy:.4f}.pth"
            save_model(model, optimizer, scheduler, epoch, val_accuracy, model_filename)
            save_model(model, optimizer, scheduler, epoch, val_accuracy, model_save_path)
            print(f"Saved the best model at epoch {epoch + 1} with val accuracy: {val_accuracy:.4f}")

        print(f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
              f"Val Accuracy: {val_accuracy:.4f}, LR: {optimizer.param_groups[0]['lr']:.6f}")

    # Evaluate on the test dataset after training
    test_loss, test_accuracy = validate_model(model, test_loader, ctc_loss, char_list, device)
    print(f"Final Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

    return train_losses, val_losses, val_accuracies

# Validation: mean CTC loss per batch and position-wise character accuracy
def validate_model(model, val_loader, ctc_loss, char_list, device):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Args:
        model: network returning (seq_len, batch, num_classes) scores; switched to eval mode.
        val_loader: loader yielding dicts with 'image', 'label' and 'label_len'.
        ctc_loss: loss module called as ``ctc_loss(log_probs, targets, input_lengths, target_lengths)``.
        char_list: vocabulary used for greedy decoding.
        device: device that batches are moved to.

    Returns:
        (mean loss per batch, character accuracy).

    Accuracy compares each decoded prediction with its ground truth position by
    position, over the length of the shorter string, and divides the matches by the
    total ground-truth length. It is 0.0 when there are no ground-truth characters.
    An empty loader raises ZeroDivisionError in the loss average.
    """
    model.eval()
    running_loss = 0.0
    matched = 0
    reference_chars = 0

    with torch.no_grad():
        for batch in val_loader:
            images = batch['image'].to(device)
            targets = batch['label'].to(device)
            target_lengths = batch['label_len'].to(device)

            logits = model(images)
            seq_len, batch_size = logits.size(0), logits.size(1)
            input_lengths = torch.full(size=(batch_size,), fill_value=seq_len, dtype=torch.long).to(device)
            running_loss += ctc_loss(logits.log_softmax(2), targets, input_lengths, target_lengths).item()

            decoded = decode_predictions(logits, char_list)
            references = decode_ground_truths(targets, char_list)
            for pred, truth in zip(decoded, references):
                matched += sum(a == b for a, b in zip(pred, truth))
                reference_chars += len(truth)

    accuracy = matched / reference_chars if reference_chars > 0 else 0.0
    return running_loss / len(val_loader), accuracy

# Save model function
def save_model(model, optimizer, scheduler, epoch, best_val_accuracy, path):
    """Write a training checkpoint to ``path``, creating parent directories if needed.

    The checkpoint is a dict with keys ``model_state_dict``, ``optimizer_state_dict``,
    ``scheduler_state_dict``, ``epoch`` and ``best_val_accuracy``.
    """
    parent = os.path.dirname(path)
    # os.makedirs('') raises FileNotFoundError, so only create a directory when the path
    # has one (a bare file name targets the current directory).
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'epoch': epoch,
        'best_val_accuracy': best_val_accuracy
    }, path)
    print(f"Model saved to {path}")

# Load model function
def load_model(model, path, map_location=None):
    """Load ``model_state_dict`` from the checkpoint at ``path`` into ``model`` and set eval mode.

    Args:
        model: module to load weights into.
        path: checkpoint file written by ``save_model``.
        map_location: forwarded to ``torch.load``. When None, tensors are mapped to the
            device of the model's first parameter (or the CPU for a parameterless
            model), so a GPU checkpoint also loads on a CPU-only host.

    Returns:
        The full checkpoint dict, or None (after printing a message) if ``path`` does
        not exist.
    """
    if not os.path.exists(path):
        print(f"No model found at {path}")
        return None

    if map_location is None:
        first_param = next(model.parameters(), None)
        map_location = first_param.device if first_param is not None else "cpu"

    checkpoint = torch.load(path, map_location=map_location, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print(f"Model loaded from {path}")
    return checkpoint

# Decode predictions using CTC
def decode_predictions(outputs, char_list):
    """Greedy (best-path) CTC decoding of model outputs.

    Args:
        outputs: tensor of shape (seq_len, batch, num_classes) with per-step class
            scores. Only the arg-max is used, so logits or log-probabilities both work.
        char_list: vocabulary indexed by class id; its last entry is the blank.

    Returns:
        A list with one decoded string per batch element.
    """
    blank = len(char_list) - 1
    # argmax is unchanged by softmax, so the raw scores are used directly.
    best_path = outputs.argmax(2).transpose(1, 0).cpu().numpy()  # (batch, seq_len)
    decoded = []
    for sequence in best_path:
        chars = []
        prev = None
        for idx in sequence:
            # Collapse consecutive repeats and drop blanks.
            if idx != prev and idx != blank:
                chars.append(char_list[idx])
            prev = idx
        decoded.append(''.join(chars))
    return decoded

# Decode ground truths
def decode_ground_truths(targets, char_list):
    """Convert padded label rows back to strings, dropping every blank/padding index.

    Args:
        targets: 2-D tensor (or nested sequence) of class indices, one row per sample.
        char_list: vocabulary; its last entry is the blank/padding index that is removed.

    Returns:
        A list with one decoded string per row.
    """
    blank = len(char_list) - 1
    # Convert once to Python ints. Iterating a tensor element by element creates a 0-d
    # tensor per entry and, on an accelerator, syncs the device for each comparison.
    rows = targets.tolist() if torch.is_tensor(targets) else targets
    return [''.join(char_list[c] for c in row if c != blank) for row in rows]

# Visualize predictions for a few random test data points
def visualize_predictions(model, test_loader, char_list, device, num_samples=5):
    """Plot predictions for randomly chosen samples of ``test_loader.dataset``.

    Up to ``num_samples`` distinct samples are drawn, or fewer if the dataset is
    smaller. Each figure shows the image titled with the decoded prediction, the
    ground truth and the count of position-wise matching characters.
    """
    model.eval()
    dataset = test_loader.dataset
    sample_count = min(num_samples, len(dataset))
    with torch.no_grad():
        # Draw indices instead of materialising list(dataset). That would fetch every
        # item just to keep a few, and random.sample raises when the population is
        # smaller than the requested sample size.
        for idx in random.sample(range(len(dataset)), sample_count):
            sample = dataset[idx]
            image = sample['image'].unsqueeze(0).to(device)  # Add batch dimension
            target = sample['label'].to(device)

            output = model(image)
            prediction = decode_predictions(output, char_list)[0]
            ground_truth = decode_ground_truths(target.unsqueeze(0), char_list)[0]

            correct_chars = sum(p == g for p, g in zip(prediction, ground_truth))
            total_chars = len(ground_truth)

            plt.imshow(image.squeeze(0).squeeze(0).cpu().numpy(), cmap='gray')
            plt.title(f'Prediction: {prediction}\nGround Truth: {ground_truth}\n'
                      f'Correct Characters: {correct_chars}/{total_chars}')
            plt.axis('off')
            plt.show()

# Evaluate model performance on the test dataset
def evaluate_model(model, test_loader, char_list, device):
    """Print the CTC loss and character accuracy of ``model`` on ``test_loader``.

    The last entry of ``char_list`` is used as the CTC blank.
    """
    criterion = nn.CTCLoss(blank=len(char_list) - 1, zero_infinity=True)
    test_loss, test_accuracy = validate_model(model, test_loader, criterion, char_list, device)
    print(f"Final Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

# Main evaluation and visualization flow
def main_evaluation_and_visualization(model, val_loader, test_loader, char_list, device):
    """Plot a few test predictions, then print aggregate test metrics.

    ``val_loader`` is accepted for interface compatibility but is not used.
    """
    steps = (
        ("Visualizing predictions on test data:", visualize_predictions),
        ("Evaluating model performance on test data:", evaluate_model),
    )
    for message, step in steps:
        print(message)
        step(model, test_loader, char_list, device)

# Visualize predictions with statistics
def visualize_predictions_with_stats(model, loader, char_list, device, max_display=5):
    """Decode every batch of ``loader``, plot a few examples and print character statistics.

    Args:
        model: network returning (seq_len, batch, num_classes) scores.
        loader: loader yielding dicts with 'image' and 'label'.
        char_list: vocabulary used for decoding.
        device: device that batches are moved to.
        max_display: total number of example figures to show across all batches.

    Returns:
        (accuracy, processing_time): position-wise character accuracy (0.0 if there
        are no ground-truth characters) and elapsed wall-clock seconds. The elapsed
        time includes plotting.
    """
    model.eval()
    start_time = time.time()

    total_correct = 0
    total_characters = 0
    pred_char_counts = Counter()
    true_char_counts = Counter()
    shown = 0  # Counted across batches so at most max_display figures appear in total

    with torch.no_grad():
        for batch in loader:
            images = batch['image'].to(device)
            targets = batch['label'].to(device)

            outputs = model(images)
            predictions = decode_predictions(outputs, char_list)
            ground_truths = decode_ground_truths(targets, char_list)

            for i, (pred, gt) in enumerate(zip(predictions, ground_truths)):
                # Count correct characters and total characters
                correct_chars = sum(p == g for p, g in zip(pred, gt))
                total_chars = len(gt)
                total_correct += correct_chars
                total_characters += total_chars

                # Update character frequency counters
                pred_char_counts.update(pred)
                true_char_counts.update(gt)

                if shown < max_display:
                    plt.imshow(images[i].squeeze(0).cpu().numpy(), cmap='gray')
                    plt.title(f'Prediction: {pred}\nGround Truth: {gt}\n'
                              f'Correct Characters: {correct_chars}/{total_chars}')
                    plt.show()
                    shown += 1

    accuracy = total_correct / total_characters if total_characters > 0 else 0.0
    processing_time = time.time() - start_time

    # Print frequency comparison
    print("\nCharacter Frequency Comparison:")
    print(f"{'Character':<10} {'Predicted':<10} {'True':<10}")
    for char in sorted(set(pred_char_counts.keys()).union(set(true_char_counts.keys()))):
        print(f"{char:<10} {pred_char_counts[char]:<10} {true_char_counts[char]:<10}")

    # Print final accuracy and processing time
    print(f"\nOverall Accuracy: {accuracy:.4f}")
    print(f"Processing Time: {processing_time:.2f} seconds")

    return accuracy, processing_time


In [None]:
# Hyperparameters and settings
img_channel = 1  # Grayscale
img_height = 32
# NOTE(review): RecogDataset is created later with its default target_size=(32, 512), so images
# reach the model 512 px wide rather than img_width. ModifiedCRNN2 only uses img_width for a
# divisibility check and its reported output width, so this runs, but the two values disagree —
# confirm which width is intended.
img_width = 352
learning_rate = 0.001
# Computed from the char_list string defined earlier (63 entries). The list assigned in the next
# cell has the same length, so the class count still matches.
num_classes = len(char_list)
batch_size = 4
num_epochs = 512
model_save_path = '/kaggle/working/modified_crnn_model/checkpoint.pth'
#model_save_path = "/kaggle/input/modified_ccrn_model/pytorch/default/1/modified_ccrn_model.pth"
data_path = '/kaggle/input/dataset-cc/final_dataset/recognition/images_with_labels'

In [None]:
# Getting the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Vocabulary used from here on. It replaces the earlier char_list string (same 63 entries,
# different order). The last entry, ' ', is the padding index in encode_to_labels and the CTC
# blank in train.
char_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 
             'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 
             'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 
             'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 
             'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 
             'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 
             'y', 'z', ' ']  # Define your character set

# Prepare dataset and dataloaders
# NOTE(review): with augment=True (the default), RecogDataset appends augmented copies of existing
# samples before the split below. Near-duplicates of a training image can therefore land in
# val/test, which inflates those metrics — confirm this is acceptable.
full_dataset = RecogDataset(data_path=data_path, char_list=char_list)
#full_dataset = RecogDataset(data_path=data_path, char_list=char_list, target_size=(32, 512), pred_len=11, augment=False)

# Plot character distribution (counts are taken after balance_dataset has run)
full_dataset.plot_character_distribution()

# Split the dataset into train (70%), validation (20%), and test (10%)
train_indices, temp_indices = train_test_split(
    list(range(len(full_dataset))), test_size=0.3, random_state=42
)
val_indices, test_indices = train_test_split(
    temp_indices, test_size=1/3, random_state=42  # 1/3 of 30% is 10%
)

train_dataset = Subset(full_dataset, train_indices)
val_dataset = Subset(full_dataset, val_indices)
test_dataset = Subset(full_dataset, test_indices)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Instantiate the model
#model = ModifiedCRNN1(img_channel, img_height, img_width, num_classes)
model = ModifiedCRNN2(img_channel, img_height, img_width, num_classes)

# Move the model's parameters to the selected device (nn.Module.to works in place).
model.to(device)

In [None]:
# Resumes from model_save_path if it exists; returns per-epoch histories for the epochs run here.
train_losses, val_losses, val_accuracies = train(model, train_loader, val_loader, test_loader, num_epochs, learning_rate, char_list, model_save_path, device)

In [None]:
# Plot per-epoch training and validation loss; `_ =` discards each call's return value.
_ = plt.plot(train_losses)
_ = plt.plot(val_losses)
_ = plt.title("Train and Validation Loss")
_ = plt.xlabel("Epochs")
_ = plt.ylabel("Loss")
_ = plt.legend(["Training Loss", "Validation Loss"])

In [None]:
# Example usage after training: plot a few random test samples, then print test loss and accuracy.
main_evaluation_and_visualization(model, val_loader, test_loader, char_list, device)

In [None]:

# Load the model (optional)
#loaded_model = ModifiedCRNN1(img_channel, img_height, img_width, num_classes)
loaded_model = ModifiedCRNN2(img_channel, img_height, img_width, num_classes)
loaded_model.to(device)
# map_location lets a checkpoint saved on a GPU be loaded in a CPU-only session.
checkpoint = torch.load(model_save_path, map_location=device, weights_only=True)
loaded_model.load_state_dict(checkpoint['model_state_dict'])
# Visualize some predictions from the validation split
visualize_predictions(loaded_model, val_loader, char_list, device)

In [None]:
# Visualize predictions and analyze statistics on the validation split
# (returns character accuracy and elapsed seconds, which include plotting time).
accuracy, processing_time = visualize_predictions_with_stats(loaded_model, val_loader, char_list, device)



In [None]:
# Report the values returned by visualize_predictions_with_stats on val_loader.
print(f"Final Validation Accuracy: {accuracy:.4f}")
print(f"Model Processing Time: {processing_time:.2f} seconds")