# Handwritten Character Recognition - Model Training

This notebook focuses on the training aspects of a handwritten character recognition system using PyTorch. It covers:
- Setting up the data pipeline with image loading and augmentations.
- Defining various Convolutional Neural Network (CNN) architectures, including custom CNNs and a VGG19-based transfer learning model.
- Utility functions for the training loop, model saving/loading, and evaluation.
- Conducting training experiments for the defined models.

## 1. Imports and Setup

In [None]:
# General utilities
import os
import random
import copy
import time

# Image processing and display
import matplotlib.pyplot as plt
from PIL import Image
import cv2 # OpenCV for image operations
import numpy as np

# PyTorch essentials
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.models as models

# For dataset path handling (if needed)
from pathlib import Path

# Note: the Kaggle API setup and dataset download cells from the consolidated notebook are omitted here.
# The dataset is expected to be available locally: set `data_root_example` (Section 8) to its path.

## 2. Device Configuration

In [None]:
# Device configuration: prefer Apple-silicon MPS, then CUDA, and fall back to CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available() and torch.backends.mps.is_built()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print(f"Using device: {device}")

# The multiprocessing start method only matters when a DataLoader uses worker
# processes (num_workers > 0); the loaders in this notebook use num_workers=0.
# If you enable workers and they fail to start, uncomment the block below.
# try:
#     torch.multiprocessing.set_start_method('spawn', force=True)
#     print("Set multiprocessing start method to 'spawn'.")
# except RuntimeError as e:
#     print(f"Note: {e}. Multiprocessing start method might have been already set.")

## 3. Custom Data Augmentation Transforms

In [None]:
class RandomChoice(torch.nn.Module):
    """With probability ``p``, apply one transform drawn uniformly from ``transforms``.

    Args:
        transforms (list[callable]): candidate transforms; exactly one is used
            when the random draw succeeds.
        p (float): probability of applying any transform at all (default: 0.5).
            Otherwise the input is returned untouched.
    """
    def __init__(self, transforms, p=0.5):
        super().__init__()
        self.transforms = transforms
        self.p = p

    def __call__(self, img):
        # Guard clause: the draw failed, so pass the image through unchanged.
        if not random.random() < self.p:
            return img
        picked = random.choice(self.transforms)
        return picked(img)

class ThicknessTransform(torch.nn.Module):
    """Randomly dilate or erode an image to vary stroke thickness.

    Each call converts the input to a single-channel grayscale array. It then applies
    ``cv2.dilate`` or ``cv2.erode`` (each with probability 0.5) using a square
    all-ones kernel, and returns an 8-bit grayscale ('L') PIL image.

    Direction note: dilation is a local maximum, so it grows *bright* regions.
    For dark ink on a light background, dilation makes strokes thinner and
    erosion makes them thicker. (The augmentation pipeline pads with
    ``fill=255``, which suggests light backgrounds; confirm for your data.)

    Args:
        kernel_size (int or sequence of int): side length of the kernel
            (default: 3). If a list/tuple is given, one value is drawn per call.
        iterations (int or sequence of int): number of times to apply the
            operation (default: 1). If a list/tuple is given, one value is
            drawn per call.
    """
    def __init__(self, kernel_size=3, iterations=1):
        super().__init__()
        self.kernel_size = kernel_size
        self.iterations = iterations

    @staticmethod
    def _pick(value):
        """Return one random element of a list/tuple, or ``value`` itself otherwise."""
        return random.choice(value) if isinstance(value, (list, tuple)) else value

    def __call__(self, img):
        img_cv = np.array(img)
        if img_cv.ndim == 3:
            channels = img_cv.shape[2]
            if channels == 3:
                img_cv = cv2.cvtColor(img_cv, cv2.COLOR_RGB2GRAY)
            elif channels == 4:
                # Without this, RGBA input would reach Image.fromarray(mode='L')
                # as a 4-channel array.
                img_cv = cv2.cvtColor(img_cv, cv2.COLOR_RGBA2GRAY)
            elif channels == 1:
                img_cv = img_cv[:, :, 0]

        # Sampling here (rather than in __init__) lets sequences vary per image.
        kernel_size = self._pick(self.kernel_size)
        iterations = self._pick(self.iterations)
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        if random.random() > 0.5:
            processed_img = cv2.dilate(img_cv, kernel, iterations=iterations)
        else:
            processed_img = cv2.erode(img_cv, kernel, iterations=iterations)
        return Image.fromarray(processed_img, mode='L')

## 4. Handwriting Data Pipeline

In [None]:
class HandwritingDataPipeline:
    """Load an ImageFolder dataset, split it into train/val/test and build loaders.

    Every split is resized to ``image_size``, converted to grayscale, replicated
    to three channels and normalised with ImageNet statistics. When
    ``do_transform`` is true, the training split additionally receives random
    geometric, stroke-thickness, blur, brightness/contrast and erasing
    augmentations.

    Args:
        data_root (str): root directory in ``torchvision.datasets.ImageFolder``
            layout (one sub-directory per class).
        image_size (tuple[int, int]): target size passed to ``transforms.Resize``.
        batch_size (int): batch size used by all three loaders.
        do_transform (bool): enable training-time augmentation.
        test_split (float): fraction of the full dataset held out for testing.
        val_split (float): fraction of the full dataset used for validation.

    After construction the instance exposes ``class_names``, ``num_classes``,
    ``train_dataset``, ``val_dataset``, ``test_dataset`` and ``sizes``.
    """
    def __init__(self, data_root, image_size=(64, 64), batch_size=32, do_transform=True, test_split=0.15, val_split=0.15):
        self.data_root = data_root
        self.image_size = image_size
        self.batch_size = batch_size
        self.do_transform = do_transform
        self.test_split = test_split
        self.val_split = val_split
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self._setup_transforms()
        self._load_and_split_datasets()

    def _setup_transforms(self):
        """Build ``self.train_transform`` and ``self.val_test_transform``."""
        # Deterministic preprocessing for validation/test, also used for
        # training when augmentation is disabled.
        self.val_test_transform = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
            # The models take 3-channel input: replicate the single gray channel.
            transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0) == 1 else x),
            self.normalize,
        ])
        if not self.do_transform:
            self.train_transform = self.val_test_transform
            return

        # One ThicknessTransform per (kernel_size, iterations) pair. RandomChoice
        # with p=1.0 picks one of them per image, so the thickness setting varies
        # per sample. Calling random.choice() in the constructor arguments instead
        # would freeze a single setting for the whole run.
        thickness_variants = [
            ThicknessTransform(kernel_size=k, iterations=n)
            for k in (1, 2, 3)
            for n in (1, 2)
        ]
        self.train_transform = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.Grayscale(num_output_channels=1),
            RandomChoice([
                transforms.RandomAffine(degrees=20, translate=(0.2, 0.2), scale=(0.8, 1.2), shear=10, fill=255),
                transforms.RandomPerspective(distortion_scale=0.3, p=0.5, fill=255),
                transforms.RandomRotation(15, fill=255),
            ], p=0.8),
            RandomChoice(thickness_variants, p=1.0),
            transforms.RandomApply([
                transforms.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 0.5))
            ], p=0.3),
            transforms.ColorJitter(brightness=0.3, contrast=0.3),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0) == 1 else x),
            self.normalize,
            transforms.RandomErasing(p=0.2, scale=(0.02, 0.03), ratio=(0.3, 3.3), value='random'),
        ])

    def _compute_split_sizes(self, total_size):
        """Return ``(train, val, test)`` sizes, or ``None`` if there are fewer than 3 samples."""
        test_size = int(total_size * self.test_split)
        remaining_size = total_size - test_size
        val_size = int(remaining_size * (self.val_split / (1.0 - self.test_split)))
        train_size = remaining_size - val_size
        if train_size > 0 and val_size > 0 and test_size > 0:
            return train_size, val_size, test_size

        print(f"Warning: Dataset too small for current split ratios. Total: {total_size}")
        if total_size < 3:
            return None
        # Fallback 70/15/15 split that guarantees at least one sample per split.
        train_size = max(1, int(total_size * 0.7))
        val_size = max(1, int(total_size * 0.15))
        test_size = total_size - train_size - val_size
        if test_size <= 0:
            test_size = 1
            val_size = total_size - train_size - test_size
            if val_size <= 0:
                val_size = 1
                train_size = total_size - val_size - test_size
        return train_size, val_size, test_size

    def _load_and_split_datasets(self):
        """Load the ImageFolder, split it reproducibly (seed 42) and wrap each split."""
        full_dataset = datasets.ImageFolder(root=self.data_root)
        self.class_names = full_dataset.classes
        self.num_classes = len(self.class_names)

        sizes = self._compute_split_sizes(len(full_dataset))
        if sizes is None:
            # Too few samples to split at all: reuse everything for every split.
            print("Using full dataset for train/val/test because it has fewer than 3 samples. THIS IS NOT RECOMMENDED FOR ACTUAL TRAINING.")
            train_dataset_subset = val_dataset_subset = test_dataset_subset = full_dataset
        else:
            train_size, val_size, test_size = sizes
            print(f"Attempting to split: Train={train_size}, Val={val_size}, Test={test_size}")
            try:
                train_temp_dataset, test_dataset_subset = torch.utils.data.random_split(
                    full_dataset, [train_size + val_size, test_size],
                    generator=torch.Generator().manual_seed(42))
                train_dataset_subset, val_dataset_subset = torch.utils.data.random_split(
                    train_temp_dataset, [train_size, val_size],
                    generator=torch.Generator().manual_seed(42))
            except Exception as e:
                # Deliberate best-effort fallback so the notebook keeps running.
                print(f"Error during dataset splitting: {e}. Adjusting split sizes or check dataset.")
                print("Using full dataset for train/val/test due to splitting error. THIS IS NOT RECOMMENDED FOR ACTUAL TRAINING.")
                train_dataset_subset, val_dataset_subset, test_dataset_subset = full_dataset, full_dataset, full_dataset

        self.train_dataset = TransformedDataset(train_dataset_subset, transform=self.train_transform)
        self.val_dataset = TransformedDataset(val_dataset_subset, transform=self.val_test_transform)
        self.test_dataset = TransformedDataset(test_dataset_subset, transform=self.val_test_transform)

        self.sizes = {'train': len(self.train_dataset), 'val': len(self.val_dataset), 'test': len(self.test_dataset)}

    def get_loaders(self, shuffle_train=True, shuffle_val=False, shuffle_test=False):
        """Return ``(train_loader, val_loader, test_loader)`` (single-process loading)."""
        train_len = len(self.train_dataset)
        # BatchNorm1d (used by ImprovedLetterCNN's head) raises in train mode on a
        # batch of one sample, so drop a trailing singleton training batch.
        drop_last_train = train_len > self.batch_size and train_len % self.batch_size == 1
        train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=shuffle_train,
                                  num_workers=0, drop_last=drop_last_train)
        val_loader = DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=shuffle_val, num_workers=0)
        test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=shuffle_test, num_workers=0)
        return train_loader, val_loader, test_loader

    def get_class_labels(self):
        """Return the class names in ImageFolder index order."""
        return self.class_names

class TransformedDataset(Dataset):
    """Apply ``transform`` lazily to the samples of a ``(sample, target)`` dataset.

    This lets each split returned by ``random_split`` carry its own transform
    while sharing the underlying data. Targets pass through unchanged, and a
    falsy ``transform`` means samples are returned as-is.
    """
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        sample, target = self.subset[index]
        return (self.transform(sample) if self.transform else sample), target

## 5. Display Augmented Images (Helper Function)

In [None]:
def _resolve_base_dataset(dataset):
    """Unwrap nested ``Subset``s; return ``(base_dataset, base_indices)`` for every element."""
    # Element k of Subset(S, idx) is S[idx[k]], so map positions outward-in.
    indices = list(range(len(dataset)))
    while isinstance(dataset, torch.utils.data.Subset):
        indices = [dataset.indices[i] for i in indices]
        dataset = dataset.dataset
    return dataset, indices


def display_augmented_images(data_loader, num_images=5, num_augmentations=3):
    """Displays original and augmented images from the train_loader.

    Expects ``data_loader.dataset`` to be a ``TransformedDataset`` whose
    ``subset`` is an ImageFolder, possibly wrapped in one or more ``Subset``s.
    ``HandwritingDataPipeline`` produces two nested Subsets. Each row shows
    one random image followed by ``num_augmentations`` independently
    augmented copies, produced by the loader's ``transform``.

    Args:
        data_loader (DataLoader): The DataLoader for the training set.
        num_images (int): Number of unique images to display.
        num_augmentations (int): Number of augmented versions to show per image.
    """
    wrapped = getattr(data_loader, 'dataset', None)
    subset = getattr(wrapped, 'subset', None)
    base, base_indices = _resolve_base_dataset(subset) if subset is not None else (None, None)
    if base is None or not hasattr(base, 'class_to_idx') or not hasattr(base, 'imgs'):
        print("Data loader is None or not structured as expected (TransformedDataset -> Subset(s) -> ImageFolder). Cannot display images.")
        print("Please ensure the data pipeline was initialized correctly and is feeding this function.")
        return

    idx_to_class = {v: k for k, v in base.class_to_idx.items()}

    if len(base_indices) < num_images:
        print(f"Warning: Requested {num_images} images, but dataset only has {len(base_indices)}. Displaying all available.")
        num_images = len(base_indices)

    if num_images == 0:
        print("No images to display.")
        return

    chosen = random.sample(base_indices, num_images)
    train_transform = wrapped.transform
    cols = num_augmentations + 1  # original image + its augmentations
    # Undo the ImageNet normalisation applied by the pipeline before plotting.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    plt.figure(figsize=(cols * 3, num_images * 3))
    for row, base_idx in enumerate(chosen):
        original_path, true_label_idx = base.imgs[base_idx]
        class_name = idx_to_class[true_label_idx]
        # torchvision's default ImageFolder loader converts to RGB; do the same so
        # the preview matches training input. `with` closes the file handle.
        with Image.open(original_path) as opened:
            original_pil = opened.convert("RGB")

        ax = plt.subplot(num_images, cols, row * cols + 1)
        ax.imshow(original_pil)
        ax.set_title(f'Original: {class_name}')
        ax.axis('off')

        for j in range(num_augmentations):
            augmented_tensor = train_transform(original_pil.copy())
            ax = plt.subplot(num_images, cols, row * cols + j + 2)
            img_display = augmented_tensor.cpu().numpy().transpose((1, 2, 0))
            img_display = np.clip(std * img_display + mean, 0, 1)
            if img_display.shape[2] == 1:
                ax.imshow(img_display.squeeze(), cmap='gray')
            else:
                ax.imshow(img_display)
            ax.set_title(f'Aug {j+1}: {class_name}')
            ax.axis('off')
    plt.tight_layout()
    plt.show()

## 6. Model Architectures

This section defines the neural network models used for character recognition.

### 6.1. Custom CNNs (`LetterCNN64`, `ImprovedLetterCNN`)

In [None]:
class LetterCNN64(nn.Module):
    """Three-block CNN classifier designed for 64x64, 3-channel inputs.

    Each block is a 3x3 conv (padding 1), then ReLU, then a 2x2 max-pool, so a
    64x64 input reaches the head as a 128x8x8 feature map. A parameter-free
    adaptive average pool pins that map to 8x8 before flattening. It is an exact
    identity for 64x64 inputs and leaves ``state_dict`` keys unchanged. For other
    input sizes it keeps ``fc1``'s input width valid; without it, the old
    ``view(-1, ...)`` silently folded extra spatial positions into the batch
    dimension.

    Args:
        num_classes (int): number of output logits.
    """
    def __init__(self, num_classes):
        super(LetterCNN64, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((8, 8))
        self.fc1 = nn.Linear(128 * 8 * 8, 512)
        self.relu4 = nn.ReLU()
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.pool3(self.relu3(self.conv3(x)))
        x = self.adaptive_pool(x)
        x = torch.flatten(x, 1)  # keep the batch dimension intact
        x = self.relu4(self.fc1(x))
        x = self.fc2(x)
        return x

class ImprovedLetterCNN(nn.Module):
    """Four-block CNN with batch norm and a dropout-regularised MLP head.

    Each block is a 3x3 conv (padding 1), then BatchNorm2d, ReLU and a 2x2
    max-pool, so a 64x64 input reaches the head as a 256x4x4 map. A
    parameter-free adaptive average pool pins that map to 4x4. It is an exact
    identity for 64x64 inputs and leaves ``state_dict`` keys unchanged, and it
    lets other input sizes work instead of being mis-reshaped by
    ``view(-1, ...)``.

    The head's ``BatchNorm1d`` layers raise in training mode when a batch
    contains a single sample.

    Args:
        num_classes (int): number of output logits.
    """
    def __init__(self, num_classes):
        super(ImprovedLetterCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.relu4 = nn.ReLU()
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((4, 4))
        self.fc1 = nn.Linear(256 * 4 * 4, 1024)
        self.bn_fc1 = nn.BatchNorm1d(1024)
        self.relu_fc1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(1024, 512)
        self.bn_fc2 = nn.BatchNorm1d(512)
        self.relu_fc2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.pool1(self.relu1(self.bn1(self.conv1(x))))
        x = self.pool2(self.relu2(self.bn2(self.conv2(x))))
        x = self.pool3(self.relu3(self.bn3(self.conv3(x))))
        x = self.pool4(self.relu4(self.bn4(self.conv4(x))))
        x = self.adaptive_pool(x)
        x = torch.flatten(x, 1)  # keep the batch dimension intact
        x = self.dropout1(self.relu_fc1(self.bn_fc1(self.fc1(x))))
        x = self.dropout2(self.relu_fc2(self.bn_fc2(self.fc2(x))))
        x = self.fc3(x)
        return x

### 6.2. VGG19 Transfer Learning Model (`VGG19HandwritingModel`)

In [None]:
class VGG19HandwritingModel(nn.Module):
    """VGG19-BN convolutional backbone with a freshly initialised 3-layer head.

    Args:
        num_classes (int): number of output logits.
        device: device the backbone and head are moved to during construction.
        pretrained (bool): load torchvision's ``VGG19_BN_Weights.IMAGENET1K_V1``
            and set ``requires_grad=False`` on every backbone parameter. With
            ``False`` the backbone is randomly initialised and trainable.

    The head expects a 512x2x2 feature map, which is what the backbone produces
    for 64x64 inputs. A parameter-free adaptive average pool pins the map to
    2x2. It is an identity at 64x64, leaves ``state_dict`` keys unchanged, and
    lets other input sizes (>= 32px) reach the head with the right width.

    NOTE(review): ``requires_grad=False`` freezes weights only. BatchNorm
    running statistics inside ``features`` still update whenever the model is
    in ``train()`` mode.
    """
    def __init__(self, num_classes, device, pretrained=True):
        super(VGG19HandwritingModel, self).__init__()
        self.device = device
        vgg19 = models.vgg19_bn(weights=models.VGG19_BN_Weights.IMAGENET1K_V1 if pretrained else None)
        vgg19 = vgg19.to(device)
        self.features = vgg19.features
        if pretrained:
            for param in self.features.parameters():
                param.requires_grad = False
        self.adaptive_pool = nn.AdaptiveAvgPool2d((2, 2))
        num_features_output = 512 * 2 * 2
        self.classifier = nn.Sequential(
            nn.Linear(num_features_output, 4096),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(4096, 2048),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(2048, num_classes)
        ).to(device)
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialise the head: N(0, 0.01) weights and zero biases for Linear layers."""
        for m in self.classifier.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.features(x)
        x = self.adaptive_pool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

## 7. Training Utilities

In [None]:
def _run_epoch(model, dataloader, criterion, optimizer, is_train, batch_scheduler=None):
    """Run one training or evaluation pass; return ``(mean_loss, accuracy, num_samples)``.

    ``batch_scheduler`` (training passes only) is stepped after every optimizer
    step. Returns ``(0.0, 0.0, 0)`` when the loader yields no batches.
    """
    model.train(is_train)
    running_loss = 0.0
    running_corrects = 0
    total_samples = 0
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            if is_train:
                loss.backward()
                optimizer.step()
                if batch_scheduler is not None:
                    batch_scheduler.step()
        batch_size = inputs.size(0)
        running_loss += loss.item() * batch_size
        running_corrects += (preds == labels).sum().item()
        total_samples += batch_size
    if total_samples == 0:
        return 0.0, 0.0, 0
    return running_loss / total_samples, running_corrects / total_samples, total_samples


def _save_checkpoint(path, epoch, model, optimizer, scheduler, loss, accuracy):
    """Write a checkpoint dict with the keys ``load_model`` reads."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'accuracy': accuracy,
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
    }, path)


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=25, save_dir='model_checkpoints'):
    """Train ``model`` with a train/validation loop and checkpointing.

    Each epoch runs a training pass and then a validation pass; a phase is skipped
    if its loader is ``None`` or has an empty dataset. Batches are moved to the
    module-level ``device``.

    Scheduler stepping:
      * ``ReduceLROnPlateau``: stepped with the validation loss after each
        validation pass.
      * ``OneCycleLR``: stepped after every training batch, as it is sized in
        batches (``steps_per_epoch``).
      * any other scheduler: stepped once per epoch after the training pass.

    Checkpoints:
      * ``<save_dir>/best_model.pth``: written whenever validation accuracy
        improves on the best so far (starting from 0).
      * ``<save_dir>/final_model.pth``: weights after the last epoch.

    After saving the final checkpoint, the best weights are restored into
    ``model`` if any epoch improved on 0 accuracy. A loss curve is plotted when
    both phases produced history.

    Returns:
        tuple: ``(model, history)``, where ``history`` maps ``train_loss``,
        ``train_acc``, ``val_loss`` and ``val_acc`` to per-epoch float lists.
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        print(f"Created directory: {save_dir}")

    start_time = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_epoch = 0
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    is_plateau = isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
    is_one_cycle = isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR)

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)
        trained_this_epoch = False
        for phase, dataloader in (('train', train_loader), ('val', val_loader)):
            if dataloader is None or len(dataloader.dataset) == 0:
                print(f"Skipping {phase} phase as dataloader is None or dataset is empty.")
                continue

            is_train = phase == 'train'
            epoch_loss, epoch_acc, total_samples = _run_epoch(
                model, dataloader, criterion, optimizer, is_train,
                batch_scheduler=scheduler if (is_train and is_one_cycle) else None)
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc)

            if is_train:
                trained_this_epoch = True
                continue

            if is_plateau:
                scheduler.step(epoch_loss)
            if epoch_acc > best_acc and total_samples > 0:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                best_epoch = epoch + 1
                best_model_path = os.path.join(save_dir, 'best_model.pth')
                _save_checkpoint(best_model_path, best_epoch, model, optimizer, scheduler, epoch_loss, best_acc)
                print(f"Best model saved to {best_model_path} (Epoch {best_epoch}, Val Acc: {best_acc:.4f})")

        # Per-epoch schedulers step once, after the training pass. (Testing the
        # loop variable `phase` here would always see 'val', so they never stepped.)
        if trained_this_epoch and scheduler is not None and not is_plateau and not is_one_cycle:
            scheduler.step()
        print()

    time_elapsed = time.time() - start_time
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f} at epoch {best_epoch}')

    # Save the final-epoch weights *before* restoring the best ones, so that
    # final_model.pth really contains the final state.
    final_model_path = os.path.join(save_dir, 'final_model.pth')
    final_loss = history['val_loss'][-1] if history['val_loss'] else (history['train_loss'][-1] if history['train_loss'] else 0.0)
    final_acc = history['val_acc'][-1] if history['val_acc'] else (history['train_acc'][-1] if history['train_acc'] else 0.0)
    _save_checkpoint(final_model_path, num_epochs, model, optimizer, scheduler, final_loss, final_acc)
    print(f"Final model saved to {final_model_path}")

    if best_acc > 0:
        model.load_state_dict(best_model_wts)

    if history['train_loss'] and history['val_loss']:
        plt.figure(figsize=(10, 5))
        plt.plot(history['train_loss'], label="Train Loss")
        plt.plot(history['val_loss'], label="Validation Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.legend()
        plt.title("Training and Validation Loss Over Epochs")
        plt.show()

    return model, history

def load_model(model, optimizer, checkpoint_path, scheduler=None):
    """Restore model (and optionally optimizer/scheduler) state from a checkpoint file.

    The checkpoint is loaded with ``map_location`` set to the module-level
    ``device``. Optimizer state is restored when ``optimizer`` is given; a
    ``ValueError`` (e.g. mismatched parameter groups) is reported and ignored.
    Scheduler state is restored when both ``scheduler`` and a non-``None``
    saved state are present; any failure is reported and ignored.

    Returns:
        tuple: ``(model, optimizer, scheduler, epoch, loss, accuracy)``. When
        ``checkpoint_path`` does not exist, the inputs are returned with
        ``(0, 0.0, 0.0)`` as metadata.
    """
    if not os.path.exists(checkpoint_path):
        print(f"Checkpoint path {checkpoint_path} does not exist. Returning initial model.")
        return model, optimizer, scheduler, 0, 0.0, 0.0

    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state['model_state_dict'])

    if optimizer is not None and 'optimizer_state_dict' in state:
        try:
            optimizer.load_state_dict(state['optimizer_state_dict'])
        except ValueError as err:
            print(f"Warning: Could not load optimizer state: {err}. Optimizer will be reinitialized.")

    saved_scheduler_state = state.get('scheduler_state_dict')
    if scheduler is not None and saved_scheduler_state is not None:
        try:
            scheduler.load_state_dict(saved_scheduler_state)
        except Exception as err:
            print(f"Warning: Could not load scheduler state: {err}. Scheduler may be reinitialized or use default state.")

    start_epoch = state.get('epoch', 0)
    loss = state.get('loss', 0.0)
    accuracy = state.get('accuracy', 0.0)
    print(f"Model loaded from {checkpoint_path}. Epoch: {start_epoch}, Loss: {loss:.4f}, Accuracy: {accuracy:.4f}")
    return model, optimizer, scheduler, start_epoch, loss, accuracy

def test_model(model, test_loader, criterion=None):
    """Evaluate ``model`` on ``test_loader``; return ``(mean_loss, accuracy)`` as floats.

    Batches are moved to the module-level ``device``. ``criterion`` defaults to
    ``nn.CrossEntropyLoss()``. Returns ``(0.0, 0.0)`` when the loader is
    ``None``, its dataset is empty, or it yields no batches (e.g. ``drop_last``
    with a dataset smaller than one batch). The last case previously crashed on
    ``float.item()``.
    """
    if test_loader is None or len(test_loader.dataset) == 0:
        print("Test loader is None or dataset is empty. Skipping testing.")
        return 0.0, 0.0
    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    total_samples = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()
            total_samples += inputs.size(0)
    if total_samples == 0:
        test_loss = 0.0
        test_acc = 0.0
    else:
        test_loss = running_loss / total_samples
        test_acc = running_corrects / total_samples
    print(f'Test Loss: {test_loss:.4f} Acc: {test_acc:.4f}')
    return test_loss, test_acc


# The `test_` prefix would otherwise make pytest collect this helper as a test.
test_model.__test__ = False

def freeze_layers(model, num_layers_to_freeze):
    """Set ``requires_grad = False`` on the leading layers of ``model``.

    Two modes:

    * ``model.features`` is an ``nn.Sequential``: walk its direct children,
      counting only ``Conv2d``, ``Linear`` and ``BatchNorm2d`` modules (a conv
      followed by batch norm counts as two), and freeze the parameters of the
      first ``num_layers_to_freeze`` counted modules.
    * otherwise: freeze the first ``num_layers_to_freeze`` parameter tensors
      in ``model.parameters()`` order (a weight and its bias count separately).

    Prints a one-line summary in both modes.
    """
    features = getattr(model, 'features', None)
    if not isinstance(features, nn.Sequential):
        all_params = list(model.parameters())
        limit = min(num_layers_to_freeze, len(all_params))
        for position, param in enumerate(all_params):
            if position < limit:
                param.requires_grad = False
        print(f"Froze first {limit} parameter groups (layers) of the model.")
        return

    freezable_types = (nn.Conv2d, nn.Linear, nn.BatchNorm2d)
    counted = 0
    for child in features.children():
        if not isinstance(child, freezable_types):
            continue
        if counted < num_layers_to_freeze:
            for param in child.parameters():
                param.requires_grad = False
        counted += 1
    print(f"Froze {min(num_layers_to_freeze, counted)} layers in model.features.")

## 8. Training Experiments

This section details the training process for different models.
**Important**: Ensure `data_root_example` below points to the correct path of your dataset.

In [ ]:
# Define the root directory for your dataset.
# <<< USER: CHANGE THIS PATH to your dataset location >>>
data_root_example = "./datasets/handwritten-english/augmented_images/augmented_images1"

# Report whether the dataset directory exists; the experiment cells below test
# the same path and skip themselves when it is missing.
if os.path.exists(data_root_example):
    print(f"Using dataset root: {data_root_example}")
else:
    print(f"\nWARNING: The directory '{data_root_example}' does not exist. \nPlease ensure your dataset is available at this path or update 'data_root_example'.")
    print("Training experiments will likely fail if the path is incorrect.")

### 8.1. Experiment 1: Training `ImprovedLetterCNN`

In [None]:
print("--- Experiment 1: Training ImprovedLetterCNN ---")

# Trains ImprovedLetterCNN on the 64x64 pipeline, then evaluates both the
# returned model and the best checkpoint reloaded from disk on the test split.
if os.path.exists(data_root_example):
    save_dir_cnn = 'model_checkpoints/cnn'
    if not os.path.exists(save_dir_cnn):
        os.makedirs(save_dir_cnn)
        print(f"Created checkpoint directory: {save_dir_cnn}")

    print(f"Initializing data pipeline for CNN experiment with data_root: {data_root_example}")
    cnn_pipeline = HandwritingDataPipeline(data_root=data_root_example, image_size=(64,64), batch_size=32, do_transform=True)
    train_loader_cnn, val_loader_cnn, test_loader_cnn = cnn_pipeline.get_loaders()
    num_classes_cnn = cnn_pipeline.num_classes

    if num_classes_cnn > 0 and len(train_loader_cnn.dataset) > 0:
        print(f"Number of classes for CNN: {num_classes_cnn}")
        print(f"Train samples: {len(train_loader_cnn.dataset)}, Val samples: {len(val_loader_cnn.dataset)}, Test samples: {len(test_loader_cnn.dataset)}")

        model_cnn = ImprovedLetterCNN(num_classes_cnn).to(device)
        criterion_cnn = nn.CrossEntropyLoss()
        optimizer_cnn = optim.Adam(model_cnn.parameters(), lr=0.0005)
        # `verbose` is omitted: it has been deprecated for LR schedulers since PyTorch 2.2.
        scheduler_cnn = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_cnn, mode='min', factor=0.1, patience=3)

        num_epochs_cnn_train = 20
        print(f"Starting training for ImprovedLetterCNN for {num_epochs_cnn_train} epochs...")

        trained_model_cnn, history_cnn = train_model(model_cnn, train_loader_cnn, val_loader_cnn,
                                                   criterion_cnn, optimizer_cnn, scheduler_cnn,
                                                   num_epochs=num_epochs_cnn_train, save_dir=save_dir_cnn)

        print("\nTesting the final trained ImprovedLetterCNN model on the test set:")
        test_model(trained_model_cnn, test_loader_cnn, criterion_cnn)

        print("\nLoading the best saved ImprovedLetterCNN model and testing it on the test set:")
        best_model_cnn_instance = ImprovedLetterCNN(num_classes_cnn).to(device)
        # Only the weights are needed for evaluation, so no optimizer state is restored.
        best_cnn_model_loaded, _, _, _, _, _ = load_model(best_model_cnn_instance,
                                                       None,
                                                       os.path.join(save_dir_cnn, 'best_model.pth'))
        test_model(best_cnn_model_loaded, test_loader_cnn, criterion_cnn)

        # Optional: Display augmented images from the CNN training loader
        print("\nDisplaying a sample of augmented images from CNN training loader...")
        display_augmented_images(train_loader_cnn, num_images=3, num_augmentations=3)
    else:
        print("Skipping Experiment 1: CNN training, due to invalid number of classes or empty train loader.")
else:
    print("Skipping Experiment 1: CNN training, as the data_root_example path is invalid or dataset not found.")

### 8.2. Experiment 2: Training `VGG19HandwritingModel` (Transfer Learning)

In [None]:
print("--- Experiment 2: Training VGG19HandwritingModel ---")

# Fine-tunes a VGG19-BN backbone (or trains it from scratch) on the 64x64
# pipeline, then evaluates the returned model and the best reloaded checkpoint.
if os.path.exists(data_root_example):
    save_dir_vgg = 'model_checkpoints/vgg'
    if not os.path.exists(save_dir_vgg):
        os.makedirs(save_dir_vgg)
        print(f"Created checkpoint directory: {save_dir_vgg}")

    print(f"Initializing data pipeline for VGG experiment with data_root: {data_root_example}")
    vgg_pipeline = HandwritingDataPipeline(data_root=data_root_example, image_size=(64,64), batch_size=32, do_transform=True)
    train_loader_vgg, val_loader_vgg, test_loader_vgg = vgg_pipeline.get_loaders()
    num_classes_vgg = vgg_pipeline.num_classes
    num_epochs_vgg = 20

    if num_classes_vgg > 0 and len(train_loader_vgg.dataset) > 0:
        print(f"Number of classes for VGG: {num_classes_vgg}")
        print(f"Train samples: {len(train_loader_vgg.dataset)}, Val samples: {len(val_loader_vgg.dataset)}, Test samples: {len(test_loader_vgg.dataset)}")

        use_pretrained_vgg = True
        model_vgg = VGG19HandwritingModel(num_classes=num_classes_vgg, device=device, pretrained=use_pretrained_vgg).to(device)
        print(f"VGG19 Model initialized {'with pretrained ImageNet weights' if use_pretrained_vgg else 'from scratch'}.")

        criterion_vgg = nn.CrossEntropyLoss()

        if use_pretrained_vgg:
            # The model sets requires_grad=False on its backbone when pretrained,
            # which would make the 1e-5 "features" learning rate below a no-op.
            # Unfreeze it so the differential learning rates actually fine-tune it.
            for param in model_vgg.features.parameters():
                param.requires_grad = True
            optimizer_vgg = optim.Adam([
                {'params': model_vgg.features.parameters(), 'lr': 1e-5},
                {'params': model_vgg.classifier.parameters(), 'lr': 1e-4}
            ], weight_decay=1e-4)
            scheduler_vgg = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_vgg, T_max=num_epochs_vgg, eta_min=1e-6)
            print("Optimizer set up for pretrained VGG model with differential learning rates.")
        else:
            optimizer_vgg = optim.Adam(model_vgg.parameters(), lr=1e-3, weight_decay=1e-4)
            # OneCycleLR is sized in batches (steps_per_epoch), i.e. meant to be stepped once per batch.
            scheduler_vgg = torch.optim.lr_scheduler.OneCycleLR(optimizer_vgg, max_lr=1e-3, epochs=num_epochs_vgg, steps_per_epoch=len(train_loader_vgg))
            print("Optimizer set up for VGG model training from scratch.")

        print(f"Starting training for VGG19HandwritingModel for {num_epochs_vgg} epochs...")
        trained_model_vgg, history_vgg = train_model(model_vgg, train_loader_vgg, val_loader_vgg,
                                                   criterion_vgg, optimizer_vgg, scheduler_vgg,
                                                   num_epochs=num_epochs_vgg, save_dir=save_dir_vgg)

        print("\nTesting the final trained VGG19 model on the test set:")
        test_model(trained_model_vgg, test_loader_vgg, criterion_vgg)

        print("\nLoading the best saved VGG19 model and testing it on the test set:")
        # pretrained=False avoids downloading ImageNet weights that the checkpoint overwrites anyway.
        best_model_vgg_instance = VGG19HandwritingModel(num_classes=num_classes_vgg, device=device, pretrained=False).to(device)
        # No optimizer: a fresh single-group Adam cannot accept the saved two-group
        # state, and only the weights are needed for evaluation.
        best_vgg_model_loaded, _, _, _, _, _ = load_model(best_model_vgg_instance,
                                                       None,
                                                       os.path.join(save_dir_vgg, 'best_model.pth'))
        test_model(best_vgg_model_loaded, test_loader_vgg, criterion_vgg)
    else:
        print("Skipping Experiment 2: VGG training, due to invalid number of classes or empty train loader.")
else:
    print("Skipping Experiment 2: VGG training, as the data_root_example path is invalid or dataset not found.")