# Kew-MNIST Data Exploration

This notebook provides a comprehensive exploration of the Kew-MNIST dataset, including:
- Dataset statistics and class distribution
- Sample visualizations
- Comparison between original and synthetic data
- Statistical analysis of image properties

In [None]:
# Import necessary libraries
import sys
import os

# Make the project's `src` package importable. The notebook's other relative
# paths (../configs, ../data, ../scripts) treat the parent of the working
# directory as the repository root, so that is the directory to put on
# sys.path; going up two levels would add the directory *above* the repository.
project_root = os.path.dirname(os.path.abspath(os.getcwd()))
if project_root not in sys.path:  # avoid piling up duplicates when the cell is re-run
    sys.path.append(project_root)

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from pathlib import Path
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

# Import custom modules
from src.kew_synthetic.data.loader import KewMNISTLoader
from src.kew_synthetic.data.processor import DataProcessor
from src.kew_synthetic.evaluation.visualization import ResultVisualizer
from src.kew_synthetic.utils.config import load_config

# Set style
# Select the 'seaborn-v0_8-darkgrid' matplotlib style sheet and make seaborn's
# "husl" palette the default color palette, then confirm that setup finished.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")
print("Environment ready!")

## 1. Load Configuration and Data

First, we'll load the configuration files and download the necessary datasets.

In [None]:
# Read the model and training settings from the YAML files under ../configs/
config_path = Path("../configs/")
model_config = load_config(config_path.joinpath("model_config.yaml"))
training_config = load_config(config_path.joinpath("training_config.yaml"))

print("Configuration loaded successfully!")

# Echo the headline model settings as a quick sanity check
print(f"Model architecture: {model_config['architecture']['name']}")
data_settings = model_config['data']
image_size = data_settings['image_size']
print(f"Image size: {image_size}x{image_size}")
print(f"Number of classes: {data_settings['num_classes']}")

In [None]:
# Download data if not already present
data_dir = Path("../data")
if not (data_dir / "kew-mnist").exists():
    print("Downloading Kew-MNIST dataset...")
    !python ../scripts/download_data.py --dataset kew-mnist
    
if not (data_dir / "synthetic").exists():
    print("Downloading synthetic dataset...")
    !python ../scripts/download_data.py --dataset synthetic
    
print("Data directories ready!")

In [None]:
# Build the loader that reads from the data directory
loader = KewMNISTLoader(data_dir=data_dir)

# Fetch the original (non-augmented) train/test splits and the class names
print("Loading original Kew-MNIST dataset...")
train_split, test_split, class_names = loader.load_original_data()
X_train_orig, y_train_orig = train_split
X_test, y_test = test_split

# Report the sizes and shapes that were just loaded
n_train_images = X_train_orig.shape[0]
print(f"Original training set: {n_train_images} images")
n_test_images = X_test.shape[0]
print(f"Test set: {n_test_images} images")
image_shape = X_train_orig.shape[1:]
print(f"Image shape: {image_shape}")
print(f"Classes: {class_names}")

In [None]:
# Fetch the synthetic-enhanced training split; the other two return values
# are discarded.
print("\nLoading synthetic enhanced dataset...")
(X_train_synth, y_train_synth), _, _ = loader.load_synthetic_enhanced_data()

# Report how much larger the enhanced training set is than the original one
n_synth_enhanced = X_train_synth.shape[0]
print(f"Synthetic enhanced training set: {n_synth_enhanced} images")
n_original = X_train_orig.shape[0]
print(f"Increase from original: {n_synth_enhanced - n_original} images")
growth_ratio = n_synth_enhanced / n_original
print(f"Percentage increase: {(growth_ratio - 1) * 100:.1f}%")

## 2. Class Distribution Analysis

Let's analyze the distribution of classes in both the original and synthetic-enhanced datasets.

In [None]:
# Calculate class distributions
def calculate_class_distribution(y, class_names):
    """Summarise how many samples fall into each class.

    Args:
        y: Sized iterable of hashable class labels, one per sample. A label
            equal to ``i`` is counted towards ``class_names[i]``.
        class_names: Sequence of class names, indexed by label.

    Returns:
        pandas.DataFrame with one row per entry of ``class_names`` and the
        columns ``Class``, ``Count`` and ``Percentage`` (share of ``len(y)``
        in percent). Classes that never occur get a count of 0. For an empty
        ``y`` every percentage is 0.0 rather than a ZeroDivisionError.
    """
    class_counts = Counter(y)
    total = len(y)

    # Look each count up once and reuse it for both columns.
    counts = [class_counts[i] for i in range(len(class_names))]
    # Guard the division so an empty label array yields 0% instead of crashing.
    percentages = [count / total * 100 if total else 0.0 for count in counts]

    return pd.DataFrame({
        'Class': class_names,
        'Count': counts,
        'Percentage': percentages,
    })

# Build the per-class tables for the original and the enhanced training labels
orig_dist = calculate_class_distribution(y_train_orig, class_names)
synth_dist = calculate_class_distribution(y_train_synth, class_names)

# Print each table under its heading
for heading, table in (
    ("Original Dataset Distribution:", orig_dist),
    ("\nSynthetic Enhanced Dataset Distribution:", synth_dist),
):
    print(heading)
    print(table.to_string(index=False))

In [None]:
# Visualize class distributions
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# One panel per dataset: (axes, distribution table, title, fill color, edge color)
panels = [
    (ax1, orig_dist, 'Original Kew-MNIST Class Distribution', 'skyblue', 'navy'),
    (ax2, synth_dist, 'Synthetic Enhanced Class Distribution', 'lightgreen', 'darkgreen'),
]

for ax, dist, panel_title, face_color, edge_color in panels:
    bars = ax.bar(dist['Class'], dist['Count'],
                  color=face_color, edgecolor=edge_color, alpha=0.7)
    ax.set_xlabel('Class', fontsize=12)
    ax.set_ylabel('Number of Images', fontsize=12)
    ax.set_title(panel_title, fontsize=14, fontweight='bold')
    ax.tick_params(axis='x', rotation=45)

    # Add value labels on bars. The gap is given in points (as in the grouped
    # comparison chart that follows) instead of a fixed "+50" in data units,
    # whose visual size depended on how many images the dataset contains.
    # Explicit labels keep the exact str(count) text rather than the default
    # '%g' formatting.
    ax.bar_label(bars, labels=[str(v) for v in dist['Count']],
                 padding=3, fontsize=10)

plt.tight_layout()
plt.show()

In [None]:
# Put the per-class counts of both datasets, and their difference, in one table
orig_counts = orig_dist['Count']
synth_counts = synth_dist['Count']
added_counts = synth_counts - orig_counts

comparison_df = pd.DataFrame({
    'Class': class_names,
    'Original': orig_counts,
    'Synthetic Enhanced': synth_counts,
    'Synthetic Added': added_counts,
    'Percentage Increase': (added_counts / orig_counts * 100).round(1)
})

print("Distribution Comparison:")
print(comparison_df.to_string(index=False))

# Grouped bar chart: two bars per class, centred on the class position
fig, ax = plt.subplots(figsize=(10, 6))

x = np.arange(len(class_names))
width = 0.35

bar_groups = []
for offset, column, color in (
    (-width/2, 'Original', 'skyblue'),
    (width/2, 'Synthetic Enhanced', 'lightgreen'),
):
    bar_groups.append(
        ax.bar(x + offset, comparison_df[column], width, label=column, color=color)
    )
bars1, bars2 = bar_groups

ax.set_xlabel('Class', fontsize=12)
ax.set_ylabel('Number of Images', fontsize=12)
ax.set_title('Class Distribution Comparison', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(class_names, rotation=45)
ax.legend()

# Write each bar's height just above it (3 points of vertical offset)
for bar in (*bars1, *bars2):
    height = bar.get_height()
    bar_center = bar.get_x() + bar.get_width() / 2
    ax.annotate(f'{int(height)}',
                xy=(bar_center, height),
                xytext=(0, 3),
                textcoords="offset points",
                ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.show()

## 3. Sample Images Visualization

Let's visualize sample images from each class to understand the visual characteristics of our datasets.

In [None]:
# Function to display sample images
def display_sample_images(X, y, class_names, title, samples_per_class=5):
    """Show a grid of randomly chosen example images, one row per class.

    Args:
        X: Indexable collection of images; ``X[idx]`` is passed to ``imshow``.
        y: Class labels aligned with ``X``. Label ``k`` selects the images
            shown in the row for ``class_names[k]``.
        class_names: Class names; one grid row is drawn per entry.
        title: Figure-level title.
        samples_per_class: Number of images (grid columns) per class.

    Notes:
        Sampling uses NumPy's global random state, so the selection changes
        between runs unless that state is seeded beforehand. A class with
        fewer than ``samples_per_class`` images is sampled with replacement
        (images repeat). A class with no images gets a blank, titled row
        instead of making ``np.random.choice`` raise.
    """
    n_classes = len(class_names)
    # squeeze=False keeps `axes` two-dimensional even for a single class or a
    # single sample per class, so axes[row, col] is always valid.
    fig, axes = plt.subplots(n_classes, samples_per_class,
                             figsize=(15, 3*n_classes), squeeze=False)
    # Accept plain lists as well: `list == int` would compare as a single False.
    labels = np.asarray(y)

    for class_idx in range(n_classes):
        row_axes = axes[class_idx]

        # Get indices for this class
        class_indices = np.where(labels == class_idx)[0]

        if len(class_indices) > 0:
            # Randomly select samples; repeat images only when the class has
            # fewer members than requested.
            selected_indices = np.random.choice(
                class_indices, samples_per_class,
                replace=len(class_indices) < samples_per_class)

            for ax, idx in zip(row_axes, selected_indices):
                ax.imshow(X[idx], cmap='gray')

        # Hide the axes of every cell (also the empty ones) and put the class
        # name above the first column.
        for ax in row_axes:
            ax.axis('off')
        row_axes[0].set_title(class_names[class_idx], fontsize=12, fontweight='bold')

    plt.suptitle(title, fontsize=16, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.show()

# Display sample images from original dataset
display_sample_images(X_train_orig, y_train_orig, class_names, 
                     "Sample Images from Original Kew-MNIST Dataset")

In [None]:
# Treat every entry past the first len(X_train_orig) positions as synthetic.
# NOTE(review): this assumes the enhanced set stores the original images first
# and appends the synthetic ones after them — confirm against the loader.
n_orig = len(X_train_orig)
n_enhanced = len(X_train_synth)
synthetic_indices = np.arange(n_orig, n_enhanced)

# Pull out the images and labels at those positions
X_synthetic_only, y_synthetic_only = (
    split[synthetic_indices] for split in (X_train_synth, y_train_synth)
)

n_synthetic = len(X_synthetic_only)
print(f"Number of synthetic images: {n_synthetic}")

# Show examples only when there is something to show
if n_synthetic > 0:
    display_sample_images(
        X_synthetic_only, y_synthetic_only, class_names,
        "Sample Synthetic Images Added to Dataset",
    )

## 4. Pixel Intensity Analysis

Let's analyze the pixel intensity distributions to understand the characteristics of original vs synthetic images.

In [None]:
# Analyze pixel intensity distributions
def analyze_pixel_intensity(X, title):
    """Print summary statistics over every pixel of ``X`` and return them.

    Args:
        X: NumPy array of images; all of its values are pooled together.
        title: Label printed in front of "Statistics:".

    Returns:
        Tuple ``(pixels, mean_intensity, std_intensity)`` where ``pixels`` is
        a flattened 1-D copy of ``X``.
    """
    # One long vector holding every pixel value (flatten() returns a copy)
    pixels = X.flatten()

    mean_intensity, std_intensity = np.mean(pixels), np.std(pixels)
    min_intensity, max_intensity = np.min(pixels), np.max(pixels)

    # Assemble the report first, then emit it in one go
    report_lines = (
        f"{title} Statistics:",
        f"  Mean intensity: {mean_intensity:.2f}",
        f"  Std deviation: {std_intensity:.2f}",
        f"  Min intensity: {min_intensity:.2f}",
        f"  Max intensity: {max_intensity:.2f}",
    )
    print("\n".join(report_lines))

    return pixels, mean_intensity, std_intensity

# Statistics over every pixel of the original training images
orig_pixel_stats = analyze_pixel_intensity(X_train_orig, "Original Dataset")
pixels_orig, mean_orig, std_orig = orig_pixel_stats

# Same for the synthetic-only images; fall back to placeholders when none exist
if len(X_synthetic_only) == 0:
    pixels_synth, mean_synth, std_synth = np.array([]), 0, 0
else:
    pixels_synth, mean_synth, std_synth = analyze_pixel_intensity(
        X_synthetic_only, "\nSynthetic Images"
    )

In [None]:
# Visualize pixel intensity distributions
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Sample pixels for faster plotting (if dataset is large).
# A seeded Generator makes the figure reproducible from run to run, and its
# choice() avoids the whole-population shuffle that the legacy
# np.random.choice(..., replace=False) performs on very large arrays.
max_pixels = 1000000
pixel_rng = np.random.default_rng(42)

sample_size = min(max_pixels, len(pixels_orig))
pixels_orig_sample = (
    pixels_orig if sample_size == len(pixels_orig)  # small enough: plot everything
    else pixel_rng.choice(pixels_orig, sample_size, replace=False)
)

# Original dataset distribution
ax1.hist(pixels_orig_sample, bins=50, density=True, alpha=0.7, color='skyblue', edgecolor='navy')
ax1.axvline(mean_orig, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_orig:.2f}')
ax1.set_xlabel('Pixel Intensity', fontsize=12)
ax1.set_ylabel('Density', fontsize=12)
ax1.set_title('Original Dataset Pixel Intensity Distribution', fontsize=14, fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Synthetic vs Original comparison
if len(pixels_synth) > 0:
    sample_size_synth = min(max_pixels, len(pixels_synth))
    pixels_synth_sample = (
        pixels_synth if sample_size_synth == len(pixels_synth)
        else pixel_rng.choice(pixels_synth, sample_size_synth, replace=False)
    )

    # Use one set of bin edges for both histograms so their bars line up and
    # the two densities can be compared bin by bin.
    shared_bins = np.histogram_bin_edges(
        np.concatenate([pixels_orig_sample, pixels_synth_sample]), bins=50)

    ax2.hist(pixels_orig_sample, bins=shared_bins, density=True, alpha=0.5, color='skyblue',
             edgecolor='navy', label='Original')
    ax2.hist(pixels_synth_sample, bins=shared_bins, density=True, alpha=0.5, color='lightgreen',
             edgecolor='darkgreen', label='Synthetic')
    ax2.axvline(mean_orig, color='blue', linestyle='--', linewidth=2, label=f'Original Mean: {mean_orig:.2f}')
    ax2.axvline(mean_synth, color='green', linestyle='--', linewidth=2, label=f'Synthetic Mean: {mean_synth:.2f}')
    ax2.set_xlabel('Pixel Intensity', fontsize=12)
    ax2.set_ylabel('Density', fontsize=12)
    ax2.set_title('Pixel Intensity Distribution Comparison', fontsize=14, fontweight='bold')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
else:
    ax2.text(0.5, 0.5, 'No synthetic images available',
             ha='center', va='center', transform=ax2.transAxes, fontsize=14)
    ax2.set_title('Pixel Intensity Distribution Comparison', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

## 5. Average Image Analysis

Let's compute and visualize the average image for each class to understand the typical patterns.

In [None]:
# Calculate average images for each class
def calculate_average_images(X, y, class_names):
    """Compute the pixel-wise mean image of every class.

    Args:
        X: NumPy array of images, indexed along its first axis.
        y: NumPy array of class labels aligned with ``X``.
        class_names: Class names; label ``k`` belongs to ``class_names[k]``.

    Returns:
        Array stacking one mean image per class, in ``class_names`` order.
        A class without any images contributes an all-zero image shaped
        like ``X[0]``.
    """
    def mean_image_for(label):
        # Boolean mask picks out the images carrying this label
        members = X[y == label]
        if len(members) == 0:
            return np.zeros_like(X[0])
        return np.mean(members, axis=0)

    return np.array([mean_image_for(label) for label in range(len(class_names))])

# Per-class mean images, first for the original training set and then for the
# whole synthetic-enhanced training set
avg_images_orig, avg_images_synth = (
    calculate_average_images(images, labels, class_names)
    for images, labels in (
        (X_train_orig, y_train_orig),
        (X_train_synth, y_train_synth),
    )
)

print("Average images calculated successfully!")

In [None]:
# Draw the per-class mean images: one row per dataset, one column per class
fig, axes = plt.subplots(2, len(class_names), figsize=(15, 6))

# (mean images, label written beside the row, whether to print class titles)
row_specs = (
    (avg_images_orig, 'Original', True),
    (avg_images_synth, 'Synthetic\nEnhanced', False),
)

for row, (row_images, row_label, show_titles) in enumerate(row_specs):
    for col, (avg_img, class_name) in enumerate(zip(row_images, class_names)):
        cell = axes[row, col]
        cell.imshow(avg_img, cmap='gray')
        if show_titles:
            # Class names appear only above the top row
            cell.set_title(class_name, fontsize=10)
        cell.axis('off')

        if col == 0:
            # Rotated row label to the left of the first column
            cell.text(-0.2, 0.5, row_label, rotation=90,
                      transform=cell.transAxes,
                      ha='right', va='center', fontsize=12, fontweight='bold')

plt.suptitle('Average Images by Class', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

## 6. Dataset Summary

Let's summarize the key findings from our exploratory data analysis.

In [None]:
# Create comprehensive summary
summary_data = {
    'Metric': [
        'Total Training Images (Original)',
        'Total Training Images (Synthetic Enhanced)',
        'Number of Synthetic Images Added',
        'Percentage Increase',
        'Test Set Size',
        'Image Dimensions',
        'Number of Classes',
        'Most Common Class (Original)',
        'Least Common Class (Original)',
        'Class Balance Improvement'
    ]
}

# Calculate metrics
most_common_orig = class_names[orig_dist['Count'].idxmax()]
least_common_orig = class_names[orig_dist['Count'].idxmin()]
# The extreme counts themselves; reading them directly avoids a second lookup
# by class name.
most_common_count = orig_dist['Count'].max()
least_common_count = orig_dist['Count'].min()

# Calculate class balance (std dev of class counts; lower = more balanced)
balance_orig = orig_dist['Count'].std()
balance_synth = synth_dist['Count'].std()
if balance_orig > 0:
    balance_improvement = (balance_orig - balance_synth) / balance_orig * 100
elif balance_synth > 0:
    # The original was perfectly balanced and the enhanced set is not: the
    # relative increase is unbounded.
    balance_improvement = float('-inf')
else:
    # Both std devs are zero (or undefined, e.g. a single class): no change.
    balance_improvement = 0.0

if balance_improvement > 0:
    balance_summary = f"{balance_improvement:.1f}% reduction in std dev"
elif balance_improvement < 0:
    balance_summary = f"{-balance_improvement:.1f}% increase in std dev"
else:
    balance_summary = "no change in std dev"

n_added = len(X_train_synth) - len(X_train_orig)

summary_data['Value'] = [
    f"{len(X_train_orig):,}",
    f"{len(X_train_synth):,}",
    f"{n_added:,}",
    f"{((len(X_train_synth) / len(X_train_orig)) - 1) * 100:.1f}%",
    f"{len(X_test):,}",
    f"{X_train_orig.shape[1]} × {X_train_orig.shape[2]}",
    f"{len(class_names)}",
    f"{most_common_orig} ({most_common_count:,})",
    f"{least_common_orig} ({least_common_count:,})",
    balance_summary
]

summary_df = pd.DataFrame(summary_data)
print("="*50)
print("DATASET SUMMARY")
print("="*50)
print(summary_df.to_string(index=False))

# NOTE(review): this instance is never used afterwards and nothing is saved in
# this cell. It is kept in case the constructor has side effects — confirm,
# then either call its plotting/saving API or remove it.
visualizer = ResultVisualizer()

# Report insights derived from the numbers computed above, so the closing
# claims cannot contradict the data.
print("\n✓ Exploratory data analysis complete!")
print("✓ Key insights:")
print(f"  - {n_added:,} synthetic images were added to the training set")
if balance_improvement > 0:
    print(f"  - Class balance improved: std dev of class counts fell by {balance_improvement:.1f}%")
elif balance_improvement < 0:
    print(f"  - Class balance did NOT improve: std dev of class counts rose by {-balance_improvement:.1f}%")
else:
    print("  - Class balance is unchanged (same std dev of class counts)")
if len(X_synthetic_only) > 0:
    print(f"  - Mean pixel intensity: {mean_orig:.2f} (original) vs {mean_synth:.2f} (synthetic)")