# Handwritten Character Dataset: Exploration, Preparation, and Augmentation

This notebook focuses on understanding and preparing the handwritten character dataset. It covers:
- Essential library imports and device setup.
- Custom data augmentation transformations.
- A data pipeline (`HandwritingDataPipeline`) for loading, transforming, and splitting the dataset.
- Showcasing combined data augmentations as applied by the pipeline.
- Demonstrating individual image transformations and their effects.
- Basic analysis of the dataset, including class distribution.

## 1. Imports and Setup

In [None]:
# General utilities
import os
import random
import collections

# Image processing and display
import matplotlib.pyplot as plt
from PIL import Image
import cv2 # OpenCV for image operations
import numpy as np

# PyTorch essentials
import torch
from torch.utils.data import DataLoader, Dataset
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Note: Ensure the 'data_root_example' variable later in this notebook 
# points to the correct path of your dataset for full functionality.

## 2. Device Configuration

In [None]:
# Select the compute device: prefer Apple's MPS backend, then CUDA, and fall
# back to the CPU when neither accelerator is usable.
device = torch.device(
    "mps"
    if torch.backends.mps.is_available() and torch.backends.mps.is_built()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)
print(f"Using device: {device}")

## 3. Custom Data Augmentation Transforms

These are custom PyTorch transforms used to augment the image data, helping to make the model more robust to variations in handwriting.

In [None]:
class RandomChoice(torch.nn.Module):
    """Apply one randomly selected transform with probability ``p``.

    Args:
        transforms: Sequence of callables; when the transform fires, one of
            them is picked uniformly at random.
        p (float): Probability of applying a transform at all (default: 0.5).
            Otherwise the input is returned unchanged.
    """

    def __init__(self, transforms, p=0.5):
        super().__init__()
        self.p = p
        self.transforms = transforms

    def __call__(self, img):
        # Guard clause: leave this sample untouched.
        if random.random() >= self.p:
            return img
        chosen = random.choice(self.transforms)
        return chosen(img)

class ThicknessTransform(torch.nn.Module):
    """Randomly vary stroke thickness with a morphological dilation or erosion.

    Each call picks dilation or erosion with equal probability. ``cv2.dilate``
    grows bright regions and ``cv2.erode`` grows dark regions, so which of the
    two makes strokes *thicker* depends on polarity: dilation thickens light
    strokes on a dark background, while erosion thickens dark strokes on a
    light background. Because both are equally likely, the augmentation varies
    thickness in both directions regardless of polarity.

    Args:
        kernel_size (int): Side length of the square structuring element (default: 3).
        iterations (int): Number of times to apply the operation (default: 1).

    Calling the transform on a PIL image returns a new single-channel
    (mode ``'L'``) PIL image.
    """

    def __init__(self, kernel_size=3, iterations=1):
        super().__init__()
        self.kernel_size = kernel_size
        self.iterations = iterations

    def __call__(self, img):
        # Convert the PIL image to a numpy array that OpenCV can process.
        img_cv = np.array(img)

        # The result is emitted as a 2-D mode-'L' image, so any channel axis
        # has to be collapsed before the morphology step.
        if img_cv.ndim == 3:
            channels = img_cv.shape[2]
            if channels == 3:
                img_cv = cv2.cvtColor(img_cv, cv2.COLOR_RGB2GRAY)
            elif channels == 4:
                # RGBA input: convert the colour channels and drop alpha.
                img_cv = cv2.cvtColor(img_cv, cv2.COLOR_RGBA2GRAY)
            else:
                # (H, W, 1) grayscale or (H, W, 2) luminance+alpha: keep the
                # first channel only.
                img_cv = np.ascontiguousarray(img_cv[:, :, 0])

        kernel = np.ones((self.kernel_size, self.kernel_size), np.uint8)

        if random.random() > 0.5:
            # Dilation: expands bright regions.
            processed_img = cv2.dilate(img_cv, kernel, iterations=self.iterations)
        else:
            # Erosion: expands dark regions.
            processed_img = cv2.erode(img_cv, kernel, iterations=self.iterations)

        # Convert back to a grayscale PIL image.
        return Image.fromarray(processed_img, mode='L')

## 4. Handwriting Data Pipeline

The `HandwritingDataPipeline` class encapsulates all steps for data loading, transformation, and splitting into training, validation, and test sets. It's designed to work with datasets structured in the ImageFolder format (where each subdirectory represents a class).

In [None]:
class HandwritingDataPipeline:
    """Load an ImageFolder dataset, split it, and expose DataLoaders.

    The dataset under ``data_root`` is read with ``datasets.ImageFolder`` and
    split into train/validation/test subsets using generators seeded with 42,
    so the split is reproducible. The test subset is split off first and the
    remainder is then divided into train and validation, which means the
    train/validation subsets are ``Subset`` objects nested inside another
    ``Subset``.

    Args:
        data_root (str): Root directory in ImageFolder layout (one
            sub-directory per class).
        image_size (tuple): Target size passed to ``transforms.Resize``.
        batch_size (int): Batch size used for all three DataLoaders.
        do_transform (bool): If True, the training split uses the random
            augmentation pipeline; otherwise it uses the same deterministic
            preprocessing as validation/test.
        test_split (float): Fraction of the whole dataset reserved for testing.
        val_split (float): Fraction of the whole dataset reserved for validation.

    Attributes set during construction:
        class_names / num_classes: Class labels discovered by ImageFolder.
        train_dataset / val_dataset / test_dataset: Transformed splits.
        sizes: Dict with the number of samples in each split.
        train_transform / val_test_transform: The composed transforms.
    """

    def __init__(self, data_root, image_size=(64, 64), batch_size=32, do_transform=True, test_split=0.15, val_split=0.15):
        self.data_root = data_root
        self.image_size = image_size
        self.batch_size = batch_size
        self.do_transform = do_transform
        self.test_split = test_split
        self.val_split = val_split
        # Per-channel statistics commonly used for ImageNet-pretrained models.
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        self._setup_transforms()
        self._load_and_split_datasets()

    @staticmethod
    def _to_three_channels(x):
        """Replicate a single-channel tensor to three channels; pass others through."""
        return x.repeat(3, 1, 1) if x.size(0) == 1 else x

    def _plain_transform(self):
        """Build the deterministic preprocessing (no random augmentation)."""
        return transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
            transforms.Lambda(self._to_three_channels),
            self.normalize,
        ])

    def _setup_transforms(self):
        """Create ``train_transform`` and ``val_test_transform``."""
        if self.do_transform:
            # One ThicknessTransform per (kernel_size, iterations) combination.
            # Picking among them on every call keeps the thickness parameters
            # random per sample; calling random.choice() directly in the
            # constructor arguments would fix them once for the whole run.
            thickness_variants = [
                ThicknessTransform(kernel_size=k, iterations=n)
                for k in (1, 2, 3)
                for n in (1, 2)
            ]
            # Full augmentation pipeline used for training. The output is
            # expanded to 3 channels for compatibility with models like VGG.
            self.train_transform = transforms.Compose([
                transforms.Resize(self.image_size),
                transforms.Grayscale(num_output_channels=1),
                RandomChoice([
                    transforms.RandomAffine(degrees=20, translate=(0.2, 0.2), scale=(0.8, 1.2), shear=10, fill=255),
                    transforms.RandomPerspective(distortion_scale=0.3, p=0.5, fill=255),
                    transforms.RandomRotation(15, fill=255),
                ], p=0.8),
                # p=1.0: always apply exactly one thickness variant.
                RandomChoice(thickness_variants, p=1.0),
                transforms.RandomApply([
                    transforms.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 0.5))
                ], p=0.3),
                transforms.ColorJitter(brightness=0.3, contrast=0.3),
                transforms.ToTensor(),
                transforms.Lambda(self._to_three_channels),
                self.normalize,
                transforms.RandomErasing(p=0.2, scale=(0.02, 0.03), ratio=(0.3, 3.3), value='random')
            ])
        else:
            self.train_transform = self._plain_transform()

        self.val_test_transform = self._plain_transform()

    def _set_empty_splits(self):
        """Install placeholder datasets and zero sizes when no data is usable."""
        self.train_dataset = Dataset()
        self.val_dataset = Dataset()
        self.test_dataset = Dataset()
        self.sizes = {'train': 0, 'val': 0, 'test': 0}

    @staticmethod
    def _two_stage_split(dataset, train_size, val_size, test_size):
        """Split off the test subset first, then divide the rest into train/val."""
        train_val_subset, test_subset = torch.utils.data.random_split(
            dataset, [train_size + val_size, test_size],
            generator=torch.Generator().manual_seed(42))
        train_subset, val_subset = torch.utils.data.random_split(
            train_val_subset, [train_size, val_size],
            generator=torch.Generator().manual_seed(42))
        return train_subset, val_subset, test_subset

    def _load_and_split_datasets(self):
        """Load the ImageFolder dataset and populate the three splits."""
        try:
            full_dataset = datasets.ImageFolder(root=self.data_root)
            self.class_names = full_dataset.classes
            self.num_classes = len(self.class_names)
        except FileNotFoundError:
            print(f"ERROR: Dataset not found at {self.data_root}. Please check the path.")
            self.class_names = []
            self.num_classes = 0
            self._set_empty_splits()
            return

        total_size = len(full_dataset)
        if total_size == 0:
            print(f"ERROR: Dataset at {self.data_root} is empty.")
            self._set_empty_splits()
            return

        test_size = int(total_size * self.test_split)
        remaining_size = total_size - test_size
        # val_split is a fraction of the whole dataset, so rescale it to the
        # portion that remains once the test subset has been removed.
        non_test_fraction = 1.0 - self.test_split
        if non_test_fraction > 0:
            val_size = int(remaining_size * (self.val_split / non_test_fraction))
        else:
            val_size = 0
        train_size = remaining_size - val_size

        everything = (full_dataset, full_dataset, full_dataset)
        if train_size <= 0 or val_size < 0 or test_size < 0:  # val_size or test_size can be 0 for small datasets
            print(f"Warning: Dataset too small for current split ratios (Total: {total_size}). Adjusting...")
            if total_size < 3:
                # Use all data for all sets if very small; not ideal, but it
                # prevents crashes.
                subsets = everything
            else:
                # Prioritize the training set, then validation, then test.
                train_size = max(1, int(total_size * 0.7))
                val_size = max(1, int(total_size * 0.15))
                test_size = max(0, total_size - train_size - val_size)
                subsets = self._two_stage_split(full_dataset, train_size, val_size, test_size)
        else:
            print(f"Attempting to split: Train={train_size}, Val={val_size}, Test={test_size}")
            try:
                subsets = self._two_stage_split(full_dataset, train_size, val_size, test_size)
            except Exception as e:
                # Deliberate best-effort fallback: keep the notebook usable even
                # though the three splits then overlap completely.
                print(f"Error during dataset splitting: {e}. Using full dataset for all.")
                subsets = everything

        train_subset, val_subset, test_subset = subsets
        self.train_dataset = TransformedDataset(train_subset, transform=self.train_transform)
        self.val_dataset = TransformedDataset(val_subset, transform=self.val_test_transform)
        self.test_dataset = TransformedDataset(test_subset, transform=self.val_test_transform)

        self.sizes = {'train': len(self.train_dataset), 'val': len(self.val_dataset), 'test': len(self.test_dataset)}

    def get_loaders(self, shuffle_train=True, shuffle_val=False, shuffle_test=False):
        """Return ``(train_loader, val_loader, test_loader)``.

        When no data was loaded, placeholder DataLoaders over empty ``Dataset``
        objects are returned instead.
        """
        if self.sizes['train'] == 0 and self.sizes['val'] == 0 and self.sizes['test'] == 0 and self.num_classes == 0:
            print("Pipeline not initialized properly or dataset empty/not found. Returning empty DataLoaders.")
            return DataLoader(Dataset()), DataLoader(Dataset()), DataLoader(Dataset())

        train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=shuffle_train, num_workers=0)
        val_loader = DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=shuffle_val, num_workers=0)
        test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=shuffle_test, num_workers=0)
        return train_loader, val_loader, test_loader

    def get_class_labels(self):
        """Return the list of class names discovered in the dataset."""
        return self.class_names

class TransformedDataset(Dataset):
    """Wrap an indexable dataset of ``(sample, label)`` pairs with a transform.

    Args:
        subset: Any object supporting ``len()`` and integer indexing that
            yields ``(sample, label)`` pairs.
        transform: Optional callable applied to the sample only; the label is
            passed through unchanged.
    """

    def __init__(self, subset, transform=None):
        self.transform = transform
        self.subset = subset

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        sample, label = self.subset[index]
        if not self.transform:
            return sample, label
        return self.transform(sample), label

### 4.1. Initialize Data Pipeline and Get Statistics

Set the `data_root_example` variable to the path of your dataset. The dataset should be organized in an ImageFolder structure (root_directory -> class_subdirectories -> images).

In [None]:
# <<< USER: CHANGE THIS PATH to your dataset's root folder >>>
data_root_example = "./datasets/handwritten-english/augmented_images/augmented_images1"

print(f"Attempting to initialize data pipeline with root: {data_root_example}")

# Later cells check these two names to decide whether data is available.
example_pipeline = None
train_loader_example = None

# isdir (rather than exists) also rejects a path that points at a regular file,
# for which os.listdir would raise NotADirectoryError.
if not os.path.isdir(data_root_example) or not os.listdir(data_root_example):
    print(f"\nWARNING: The directory '{data_root_example}' does not exist or is empty.")
    print("Please ensure your dataset is available at this path and structured correctly (ImageFolder format).")
    print("Subsequent cells requiring data may fail or show limited results.")
else:
    try:
        example_pipeline = HandwritingDataPipeline(data_root=data_root_example, image_size=(64,64), batch_size=16, do_transform=True)
        train_loader_example, _, _ = example_pipeline.get_loaders()
        dataset_sizes_example = example_pipeline.sizes
        num_classes_example = example_pipeline.num_classes
        class_labels_example = example_pipeline.get_class_labels()

        print(f"\nData pipeline initialized successfully.")
        print(f"Dataset split sizes: {dataset_sizes_example}")
        print(f"Number of classes: {num_classes_example}")
        if class_labels_example:
            print(f"First 10 Class labels: {class_labels_example[:10]}")
    except Exception as e:
        # Notebook-level boundary: report the problem and leave the sentinels
        # unset so that the following cells skip themselves gracefully.
        print(f"\nERROR initializing data pipeline or getting loaders: {e}")
        example_pipeline = None
        train_loader_example = None

### 4.2. Display Combined Augmented Images

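This function visualizes the effect of the complete augmentation pipeline applied to training images. It picks a few random images from the training split, reloads each original from disk, and shows it next to several independently augmented versions produced by the training transform. If the loader's dataset does not expose the expected structure, it falls back to plotting one batch of already-augmented images from the loader.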

In [None]:
def _unwrap_subsets(dataset):
    """Follow nested ``Subset``-style wrappers down to the base dataset.

    Returns:
        tuple: ``(base_dataset, indices)`` where ``indices[i]`` is the position
        in ``base_dataset`` of ``dataset[i]``. ``indices`` is ``None`` when
        ``dataset`` is not a subset wrapper at all.
    """
    indices = None
    while hasattr(dataset, 'dataset') and hasattr(dataset, 'indices'):
        level_indices = dataset.indices
        if indices is None:
            indices = list(level_indices)
        else:
            # Compose this level's mapping with the one accumulated so far.
            indices = [level_indices[i] for i in indices]
        dataset = dataset.dataset
    return dataset, indices


def _denormalize_for_display(image_tensor):
    """Undo the mean/std normalization and return an HxWxC array clipped to [0, 1]."""
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = image_tensor.cpu().numpy().transpose((1, 2, 0))
    return np.clip(std * img + mean, 0, 1)


def _display_loader_batch(data_loader, num_images, num_augmentations):
    """Fallback: plot already-augmented images taken from one batch of the loader."""
    try:
        inputs, _ = next(iter(data_loader))
        plt.figure(figsize=(num_augmentations * 2, num_images * 2))
        for i in range(min(num_images * num_augmentations, len(inputs))):
            ax = plt.subplot(num_images, num_augmentations, i + 1)
            ax.imshow(_denormalize_for_display(inputs[i]))
            ax.set_title(f'Aug. Img {i+1}')
            ax.axis('off')
        plt.tight_layout()
        plt.show()
    except Exception as e_fallback:
        print(f"Fallback display failed: {e_fallback}")


def display_augmented_images(data_loader, num_images=5, num_augmentations=3):
    """Display original images next to several augmented versions of each.

    The loader's dataset is expected to expose ``.subset`` (one or more nested
    ``Subset`` wrappers around an ImageFolder-like dataset with ``imgs`` and
    ``class_to_idx``) and ``.transform`` (the training transform). Nested
    subsets are unwrapped so that the original files can be located. When that
    structure is not present, one batch of already-augmented images from the
    loader is plotted instead.

    Args:
        data_loader: DataLoader over the training split, or None.
        num_images (int): Number of distinct source images (rows).
        num_augmentations (int): Augmented versions shown per source image.

    Returns:
        None. Images are shown with matplotlib; problems are reported via print.
    """
    if data_loader is None or not hasattr(data_loader, 'dataset') or len(data_loader.dataset) == 0:
        print("Data loader is None or empty. Cannot display images.")
        return

    wrapped = data_loader.dataset
    base_dataset, base_indices = None, None
    if hasattr(wrapped, 'subset'):
        base_dataset, base_indices = _unwrap_subsets(wrapped.subset)

    if base_dataset is None or not hasattr(base_dataset, 'class_to_idx') or not hasattr(base_dataset, 'imgs'):
        print("Data loader is not structured as expected (TransformedDataset -> Subset -> ImageFolder). Cannot display original image details.")
        _display_loader_batch(data_loader, num_images, num_augmentations)
        return

    if base_indices is None:
        # The wrapped dataset is the ImageFolder itself: every sample is eligible.
        base_indices = list(range(len(base_dataset)))

    idx_to_class = {v: k for k, v in base_dataset.class_to_idx.items()}

    if len(base_indices) < num_images:
        print(f"Warning: Requested {num_images} images, but dataset subset has {len(base_indices)}. Displaying available.")
        num_images = len(base_indices)
    if num_images == 0:
        print("No images to display from subset.")
        return

    chosen_indices = random.sample(base_indices, num_images)
    train_transform = wrapped.transform
    n_cols = num_augmentations + 1
    _, axes = plt.subplots(num_images, n_cols, figsize=(n_cols * 3, num_images * 3), squeeze=False)

    for row, dataset_idx in enumerate(chosen_indices):
        original_path, true_label_idx = base_dataset.imgs[dataset_idx]
        class_name = idx_to_class[true_label_idx]
        # Close the file promptly and hand the transform an RGB image, the
        # same mode ImageFolder's default loader produces.
        with Image.open(original_path) as handle:
            original_pil = handle.convert("RGB")

        ax = axes[row][0]
        ax.imshow(original_pil)
        ax.set_title(f'Original: {class_name}')
        ax.axis('off')

        for j in range(num_augmentations):
            # Each call re-runs the random transform, giving a new variant.
            augmented_tensor = train_transform(original_pil.copy())
            ax = axes[row][j + 1]
            ax.imshow(_denormalize_for_display(augmented_tensor))
            ax.set_title(f'Aug {j+1}: {class_name}')
            ax.axis('off')

    plt.tight_layout()
    plt.show()

# Run the demo only when the pipeline cell produced a usable training loader.
if not train_loader_example:
    print("Skipping display_augmented_images example as 'train_loader_example' is not available. Please check data_root_example.")
else:
    print("Displaying augmented images from the initialized pipeline...")
    display_augmented_images(train_loader_example, num_images=4, num_augmentations=3)

## 5. Individual Transformation Showcase

This section demonstrates the effect of individual transformations on a sample image. This helps in understanding how each augmentation technique contributes to data diversity.

In [None]:
def get_sample_image_path(data_root):
    """Return the file path of the first image in the dataset, if any.

    Args:
        data_root: Root directory of an ImageFolder-style dataset.

    Returns:
        The path of the first sample, or None when ``data_root`` is empty,
        does not exist, contains no samples, or cannot be loaded.
    """
    if data_root and os.path.exists(data_root):
        try:
            found_samples = datasets.ImageFolder(root=data_root).samples
            if len(found_samples) > 0:
                # Each sample is indexed as (path, ...); return the path.
                return found_samples[0][0]
        except Exception as e:
            print(f"Could not load sample image from dataset: {e}")
    return None

sample_image_path = get_sample_image_path(data_root_example)

# Try to load a real dataset image first; sample_pil_image stays None on failure.
sample_pil_image = None
if not sample_image_path:
    print("Sample image path not found. Creating a dummy image for transformation showcase.")
else:
    try:
        # The context manager closes the file handle; convert() returns an
        # independent in-memory copy that outlives it.
        with Image.open(sample_image_path) as sample_file:
            sample_pil_image = sample_file.convert("RGB")
    except Exception as e:
        print(f"Error opening sample image {sample_image_path}: {e}. Using dummy image.")

if sample_pil_image is None:
    # Single place that builds the stand-in image: a black 'A' on light gray.
    dummy_img_np = np.full((100, 100, 3), (200, 200, 200), dtype=np.uint8)  # Light gray background
    cv2.putText(dummy_img_np, "A", (25, 75), cv2.FONT_HERSHEY_SIMPLEX, 3, (0,0,0), 5)  # Black 'A'
    sample_pil_image = Image.fromarray(dummy_img_np)

def plot_transformed_image(original_img, transformed_img, title_original="Original", title_transformed="Transformed"):
    """Show an image and its transformed version side by side.

    Single-channel images (2-D arrays or grayscale PIL images) are rendered
    with a gray colormap on both panels; without it matplotlib would apply its
    default colormap and show them in false colour.

    Args:
        original_img: PIL image or numpy array shown on the left.
        transformed_img: PIL image or numpy array shown on the right.
        title_original (str): Title of the left panel.
        title_transformed (str): Title of the right panel.
    """
    _, axes = plt.subplots(1, 2, figsize=(8, 4))
    panels = ((original_img, title_original), (transformed_img, title_transformed))
    for ax, (image, title) in zip(axes, panels):
        # np.asarray accepts both PIL images and arrays; 2-D means one channel.
        pixels = np.asarray(image)
        if pixels.ndim == 2:
            # Pin the 8-bit range so both panels share the same gray scale
            # instead of each being stretched to its own min/max.
            value_range = {'vmin': 0, 'vmax': 255} if pixels.dtype == np.uint8 else {}
            ax.imshow(image, cmap='gray', **value_range)
        else:
            ax.imshow(image)
        ax.set_title(title)
        ax.axis('off')
    plt.show()

# The showcase needs a sample image; report an error instead when there is none.
if sample_pil_image is not None:
    print("Showing individual transformations (ensure sample_pil_image was loaded):")

    # --- 1. Random rotation: uncovered corners are filled with white (255) ---
    rotator = transforms.RandomRotation(degrees=30, fill=255)
    rotated_image = rotator(sample_pil_image.copy())
    plot_transformed_image(
        sample_pil_image,
        rotated_image,
        "Original",
        "Random Rotation (30 deg, fill white)",
    )
    print("RandomRotation: Randomly rotates the image. `fill` handles empty areas after rotation.")

    # --- 2. Gaussian blur with a randomly drawn sigma ---
    blurrer = transforms.GaussianBlur(kernel_size=(5,9), sigma=(0.1, 5.0))
    blurred_image = blurrer(sample_pil_image.copy())
    plot_transformed_image(
        sample_pil_image,
        blurred_image,
        "Original",
        "Gaussian Blur",
    )
    print("GaussianBlur: Blurs the image, can simulate out-of-focus or pen strokes.")

    # --- 3. ThicknessTransform (custom) ---
    # The transform takes a PIL image; a grayscale copy is passed here so the
    # before/after panels are directly comparable.
    thickness_transformer = ThicknessTransform(kernel_size=3, iterations=1)
    thick_image = thickness_transformer(sample_pil_image.copy().convert('L'))
    plot_transformed_image(
        sample_pil_image.convert('L'),
        thick_image,
        "Original Grayscale",
        "Thickness Transform",
    )
    print("ThicknessTransform: Simulates variations in stroke thickness using dilation/erosion.")

    # --- 4. Canny edge detection (OpenCV) on the grayscale image ---
    img_cv_for_canny = cv2.cvtColor(np.array(sample_pil_image.copy()), cv2.COLOR_RGB2GRAY)
    canny_edges = cv2.Canny(img_cv_for_canny, threshold1=100, threshold2=200)
    plot_transformed_image(
        img_cv_for_canny,
        canny_edges,
        "Original Grayscale",
        "Canny Edges",
    )
    print("Canny Edge Detection: Detects strong edges. Can be useful for feature extraction or preprocessing.")

    # --- 5. Adaptive thresholding (OpenCV) ---
    # THRESH_BINARY_INV inverts the result: white text on a black background.
    img_cv_for_thresh = cv2.cvtColor(np.array(sample_pil_image.copy()), cv2.COLOR_RGB2GRAY)
    adaptive_thresh_img = cv2.adaptiveThreshold(
        img_cv_for_thresh,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV,
        11,
        2,
    )
    plot_transformed_image(
        img_cv_for_thresh,
        adaptive_thresh_img,
        "Original Grayscale",
        "Adaptive Threshold (Inverted)",
    )
    print("Adaptive Thresholding: Useful for binarizing images with varying illumination. Here, THRESH_BINARY_INV makes objects white.")
else:
    print("ERROR: sample_pil_image is None. Cannot proceed with individual transformation showcase.")

## 6. Dataset Analysis: Class Distribution

Understanding the distribution of images across different classes is crucial. Significant imbalances can affect model training, potentially biasing the model towards over-represented classes.

In [None]:
# Class-distribution analysis; skipped when the pipeline cell found no data.
if not example_pipeline or example_pipeline.num_classes <= 0:
    print("Skipping dataset analysis as 'data_root_example' is not valid, or the pipeline was not initialized.")
    print("Please ensure 'data_root_example' points to a valid dataset to see the class distribution analysis.")
else:
    print(f"Analyzing class distribution for dataset at: {data_root_example}")
    # Re-read the whole dataset (before any splitting) so the counts describe
    # the overall distribution rather than a single split.
    try:
        full_dataset_for_analysis = datasets.ImageFolder(root=data_root_example)
        # Tally samples per class index (the second element of each sample).
        class_counts = collections.Counter(
            label for _, label in full_dataset_for_analysis.samples
        )
        class_names_analysis = list(full_dataset_for_analysis.classes)
        counts = [class_counts[class_idx] for class_idx in range(len(class_names_analysis))]

        plt.figure(figsize=(15, 7))
        plt.bar(class_names_analysis, counts)
        plt.xlabel("Class Label")
        plt.ylabel("Number of Images")
        plt.title("Class Distribution in the Dataset")
        # Vertical tick labels stay readable when there are many classes.
        plt.xticks(rotation=90, fontsize=8)
        plt.tight_layout()
        plt.show()

        # Summary statistics used to judge how balanced the classes are.
        mean_count = np.mean(counts)
        std_dev_count = np.std(counts)
        print(f"\nMean number of images per class: {mean_count:.2f}")
        print(f"Standard deviation of images per class: {std_dev_count:.2f}")
        # Heuristic: a spread above half the mean is reported as an imbalance.
        if std_dev_count > mean_count * 0.5:
            print("There appears to be a notable class imbalance. This might affect model performance.")
            print("Consider techniques like weighted sampling, oversampling minority classes, or undersampling majority classes if performance is skewed.")
        else:
            print("Class distribution appears relatively balanced.")

    except Exception as e:
        print(f"Could not perform class distribution analysis: {e}")