# EMNIST Handwritten Character Recognition - Advanced Training

This notebook trains a more robust Convolutional Neural Network (CNN) on the EMNIST `byclass` dataset with the following improvements:

1.  **Increased Data**: Uses **50%** of the training dataset.
2.  **Data Augmentation**: Applies real-time image augmentation (rotation, shift, shear, zoom) to make the model more robust to variations in handwriting.
3.  **Deeper Model**: Uses a more complex CNN architecture with more filters and dropout layers to capture finer details.
4.  **Interrupt-Safe Training**: Includes callbacks that save the training history after every epoch and save the best model (by validation accuracy) whenever it improves, so progress is not lost if training is stopped.
5.  **Comprehensive Evaluation**: Generates and saves plots for training history, a confusion matrix, and a detailed classification report.

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from sklearn.metrics import classification_report, confusion_matrix
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint

## 1. Interrupt-Safe History Callback

This custom callback saves the training history (loss, accuracy, etc.) to a file after every single epoch. If training is interrupted, we can still load this file and plot our progress.

In [None]:
class HistoryCallback(tf.keras.callbacks.Callback):
    """Callback to save history after each epoch."""
    def on_train_begin(self, logs=None):
        self.epoch = []
        self.history = {}

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epoch.append(epoch)
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)
        # Save history to a pickle file after each epoch
        with open('training_history.pkl', 'wb') as f:
            pickle.dump(self.history, f)

    def on_train_end(self, logs=None):
        print("Training finished. History is saved in 'training_history.pkl'")
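Writing the pickle directly means an interruption during `pickle.dump` could leave a truncated `training_history.pkl`. One possible hardening, sketched here as an assumption rather than part of the original notebook, is to write to a temporary file in the same directory and then move it into place with `os.replace`, which Python documents as an atomic rename on POSIX systems when source and destination are on the same filesystem. The helper name `save_history_atomically` is hypothetical; `on_epoch_end` could call `save_history_atomically(self.history)` instead of opening the file itself.

```python
import os
import pickle
import tempfile


def save_history_atomically(history, path='training_history.pkl'):
    """Hypothetical helper: pickle `history` to `path` without leaving a partial file."""
    directory = os.path.dirname(os.path.abspath(path))
    # Create the temp file next to the target so os.replace stays on one filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix='.tmp')
    try:
        with os.fdopen(fd, 'wb') as f:
            pickle.dump(history, f)
        # Swap the complete file into place with a single rename.
        os.replace(tmp_path, path)
    except BaseException:
        # Remove the temp file if anything failed (including KeyboardInterrupt), then re-raise.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


# Example: the same layout HistoryCallback builds (metric name -> list of per-epoch values).
example = {'loss': [1.2, 0.9], 'accuracy': [0.61, 0.72]}
out_path = os.path.join(tempfile.mkdtemp(), 'example_history.pkl')
save_history_atomically(example, out_path)
with open(out_path, 'rb') as f:
    restored = pickle.load(f)
print(restored == example)  # True
```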

## 2. Load and Prepare EMNIST Dataset

We load the `emnist/byclass` dataset, which contains 62 classes (10 digits, 26 uppercase letters, and 26 lowercase letters). We use **50% of the training data** and 50% of the test data: this keeps training time manageable while still giving the model more training data than the previous version, which should help accuracy.

In [None]:
# Load 50% of the training data and 50% of the test data
(ds_train, ds_test), ds_info = tfds.load(
    'emnist/byclass',
    split=['train[:50%]', 'test[:50%]'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

num_classes = ds_info.features['label'].num_classes
print(f'Number of classes: {num_classes}')

# Define the human-readable labels
labels = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
print(f'Total labels: {len(labels)}')
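The hand-typed `labels` list follows the order digits, then uppercase, then lowercase. The same list can be generated from Python's `string` constants, which avoids typos; this is a small sketch that assumes the class order used by the list above (the name `labels_generated` is just for illustration).

```python
import string

# Digits, then A-Z, then a-z: the same order as the hand-written `labels` list.
labels_generated = list(string.digits + string.ascii_uppercase + string.ascii_lowercase)

print(len(labels_generated))                        # 62
print(labels_generated[10], labels_generated[36])   # A a
```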

### Preprocess and Batch the Data

We scale the pixel values to the range [0, 1]. EMNIST images are stored transposed (they appear rotated and mirrored), so we transpose each image back to an upright orientation. Finally, we convert the datasets into NumPy arrays, which `ImageDataGenerator.flow()` expects.

In [None]:
def preprocess(image, label):
    image = tf.cast(image, tf.float32) / 255.0
    # EMNIST images are stored transposed; swap height and width to make them upright
    image = tf.transpose(image, perm=[1, 0, 2])
    return image, label

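# No .cache() here: each dataset is read once and copied into NumPy arrays below,
# so an in-memory cache would only double memory use.
ds_train = ds_train.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)

# Convert the datasets to NumPy arrays for ImageDataGenerator.
# Batching and iterating once is much faster than extracting single examples
# in two separate passes per dataset.
def to_numpy(ds, batch_size=4096):
    image_batches, label_batches = [], []
    for images, batch_labels in tfds.as_numpy(ds.batch(batch_size)):
        image_batches.append(images)
        label_batches.append(batch_labels)
    return np.concatenate(image_batches), np.concatenate(label_batches)

ds_train_images, ds_train_labels = to_numpy(ds_train)
ds_test_images, ds_test_labels = to_numpy(ds_test)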

print(f'Training images shape: {ds_train_images.shape}')
print(f'Test images shape: {ds_test_images.shape}')
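To see why a single transpose fixes the orientation, the NumPy sketch below (numbers stand in for pixels) checks that swapping the height and width axes is the same as rotating 90 degrees and then flipping, the two distortions described above. `np.transpose(img, (1, 0, 2))` is used as a stand-in for `tf.transpose(image, perm=[1, 0, 2])`; both swap the first two axes and keep the channel axis.

```python
import numpy as np

# A small non-symmetric "image" with a trailing channel axis, shaped like EMNIST's (H, W, 1).
img = np.arange(12, dtype=np.float32).reshape(3, 4, 1)

# Same permutation as in preprocess(): swap height and width, keep channels.
transposed = np.transpose(img, (1, 0, 2))

# A 90-degree counter-clockwise rotation followed by a vertical flip gives the same result.
rotated_then_flipped = np.flipud(np.rot90(img, k=1, axes=(0, 1)))

print(transposed.shape)                                   # (4, 3, 1)
print(np.array_equal(transposed, rotated_then_flipped))   # True
```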

## 3. Data Augmentation

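We use `ImageDataGenerator` to create randomly modified versions of our training images on the fly. This helps the model generalize to real-world handwriting that may be slightly rotated, shifted, sheared, or scaled. (`ImageDataGenerator` still works, but the TensorFlow documentation marks it as not recommended for new code in favour of `tf.data` pipelines with Keras preprocessing layers.)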

In [None]:
datagen = ImageDataGenerator(
    rotation_range=10,      # random rotation of up to ±10 degrees
    width_shift_range=0.1,  # random horizontal shift of up to 10% of the width
    height_shift_range=0.1, # random vertical shift of up to 10% of the height
    shear_range=10,         # random shear angle in DEGREES (0.1 would be a negligible 0.1 degrees)
    zoom_range=0.1,         # random zoom factor in [0.9, 1.1]
    horizontal_flip=False,  # no mirroring: it would turn some characters into others (e.g. 'b' -> 'd')
    vertical_flip=False     # no vertical flips either (e.g. 'b' -> 'p')
)

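# No datagen.fit() call is needed: fit() only computes dataset statistics for
# featurewise_center, featurewise_std_normalization or zca_whitening, none of which are enabled.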
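To make the augmentation settings concrete, the arithmetic below converts the rotation, shift, and zoom arguments into the ranges they produce on a 28x28 EMNIST image. It is a sketch based on the documented semantics of `ImageDataGenerator`'s arguments (a float shift below 1 is a fraction of the image size; a float `zoom_range` z becomes `[1 - z, 1 + z]`); the variable names are illustrative only.

```python
IMAGE_SIZE = 28  # EMNIST images are 28x28 pixels

rotation_range = 10   # degrees
shift_fraction = 0.1  # width_shift_range / height_shift_range
zoom_range = 0.1

# The rotation angle is drawn uniformly from [-rotation_range, rotation_range].
rotation_bounds = (-rotation_range, rotation_range)
# A float shift below 1 is a fraction of the image size: 0.1 * 28 pixels.
max_shift_px = shift_fraction * IMAGE_SIZE
# A float zoom_range z becomes the interval [1 - z, 1 + z].
zoom_bounds = (1 - zoom_range, 1 + zoom_range)

print(rotation_bounds)                            # (-10, 10)
print(round(max_shift_px, 1))                     # 2.8
print(tuple(round(z, 2) for z in zoom_bounds))    # (0.9, 1.1)
```

In other words, each character moves by at most about 3 pixels in each direction, which keeps it well inside the frame.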

## 4. Build the Improved CNN Model

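This model is deeper and wider than the previous version. It stacks two convolution + max-pooling blocks (64 and 128 filters) to learn more complex features, followed by a 256-unit dense layer, and uses dropout (0.3 after the convolutional blocks, 0.5 after the dense layer) to reduce overfitting.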

In [None]:
model = tf.keras.models.Sequential([
    # Increased filters in the first convolutional layer
    tf.keras.layers.Conv2D(64, (3,3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(2, 2),
    
    # Increased filters in the second convolutional layer
    tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Dropout(0.3), # Added dropout for regularization
    
    tf.keras.layers.Flatten(),
    
    # Increased density in the dense layer
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dropout(0.5), # Increased dropout
    
    # Output layer with softmax for multi-class classification
    tf.keras.layers.Dense(num_classes, activation='softmax')
])

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

model.summary()
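As a sanity check on `model.summary()`, the parameter counts can be worked out by hand. The sketch assumes Keras's default `'valid'` padding for `Conv2D`, 2x2 pooling with stride 2 (floor division), and `num_classes = 62`; pooling, dropout, and flatten layers have no parameters. It shows that about 90% of the roughly 0.91M parameters sit in the first `Dense` layer, which is consistent with giving that layer the heavier 0.5 dropout.

```python
def conv_params(kernel, in_ch, filters):
    # Each filter has kernel * kernel * in_ch weights plus one bias.
    return (kernel * kernel * in_ch + 1) * filters

def dense_params(in_units, out_units):
    # Full weight matrix plus one bias per output unit.
    return in_units * out_units + out_units

size = 28
conv1 = conv_params(3, 1, 64)     # 28x28x1 -> 26x26x64
size = (size - 2) // 2            # after 3x3 conv and 2x2 pool: 13
conv2 = conv_params(3, 64, 128)   # 13x13x64 -> 11x11x128
size = (size - 2) // 2            # after 3x3 conv and 2x2 pool: 5
flat = size * size * 128          # Flatten: 5 * 5 * 128 = 3200
dense1 = dense_params(flat, 256)
dense2 = dense_params(256, 62)    # 62 EMNIST byclass classes

total = conv1 + conv2 + dense1 + dense2
print(conv1, conv2, dense1, dense2)  # 640 73856 819456 15934
print(total)                         # 909886
```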

## 5. Train the Model

We train for 10 epochs with two callbacks:
1.  `HistoryCallback`: saves the per-epoch metrics to disk so they can be plotted later.
2.  `ModelCheckpoint`: saves the model whenever validation accuracy improves, so `best_model.h5` always holds the best version seen so far. This is our primary defense against losing work if the session is interrupted.

In [None]:
history_callback = HistoryCallback()

# This callback will save the model with the best validation accuracy
checkpoint_callback = ModelCheckpoint(
    filepath='best_model.h5',      # File to save the model
    save_best_only=True,         # Only save a model if `val_accuracy` has improved
    monitor='val_accuracy',      # Monitor validation accuracy
    mode='max',                  # The higher the val_accuracy, the better
    verbose=1                    # Print a message when the model is saved
)

print("Starting model training...")
history = model.fit(
    datagen.flow(ds_train_images, ds_train_labels, batch_size=256), # Use the data generator
    epochs=10,
    validation_data=(ds_test_images, ds_test_labels), # Use the original test set for validation
    callbacks=[history_callback, checkpoint_callback]
)
print("--- Model training complete ---")
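Because the test split is also passed as `validation_data`, the checkpoint is chosen using the same images that are later used for the final report. One alternative, sketched below with NumPy, is to hold out a random fraction of the training arrays for validation. The helper name `train_val_split`, the 10% fraction, and the fixed seed are assumptions, not part of the original notebook; the returned arrays would replace the training arrays in `datagen.flow(...)` and the test arrays in `validation_data`.

```python
import numpy as np

def train_val_split(images, targets, val_fraction=0.1, seed=42):
    """Hypothetical helper: randomly hold out `val_fraction` of the samples for validation."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(len(images))
    n_val = int(len(images) * val_fraction)
    val_idx, train_idx = indices[:n_val], indices[n_val:]
    return images[train_idx], targets[train_idx], images[val_idx], targets[val_idx]

# Toy stand-in for ds_train_images / ds_train_labels.
x = np.zeros((1000, 28, 28, 1), dtype=np.float32)
y = np.arange(1000) % 62
x_tr, y_tr, x_val, y_val = train_val_split(x, y)
print(x_tr.shape, x_val.shape)  # (900, 28, 28, 1) (100, 28, 28, 1)
```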

## 6. Analyze Training Results

First, we load the training history from our pickle file. This ensures we can generate these plots even if the training was stopped and the `history` object in the notebook was lost.

In [None]:
# Load the history from the pickle file for robust plotting
with open('training_history.pkl', 'rb') as f:
    history_data = pickle.load(f)

# Create a DataFrame for easy plotting
history_df = pd.DataFrame(history_data)
history_df['epoch'] = history_df.index + 1

plt.figure(figsize=(14, 5))

# Plot Training & Validation Accuracy
plt.subplot(1, 2, 1)
plt.plot(history_df['epoch'], history_df['accuracy'], label='Training Accuracy', marker='o')
plt.plot(history_df['epoch'], history_df['val_accuracy'], label='Validation Accuracy', marker='o')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)

# Plot Training & Validation Loss
plt.subplot(1, 2, 2)
plt.plot(history_df['epoch'], history_df['loss'], label='Training Loss', marker='o')
plt.plot(history_df['epoch'], history_df['val_loss'], label='Validation Loss', marker='o')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

plt.tight_layout()

# Save the plot to a file
plot_filename = 'training_history.png'
plt.savefig(plot_filename)
print(f"Saved training history plot to '{plot_filename}'")
plt.show()
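Since `ModelCheckpoint` with `save_best_only=True` only overwrites the file when `val_accuracy` strictly improves, the saved model corresponds to the first epoch that reaches the maximum validation accuracy. A small helper (hypothetical name `best_epoch`) can read that epoch from the loaded history; in the notebook it would be called as `best_epoch(history_data)`. The toy numbers below are made up.

```python
def best_epoch(history, metric='val_accuracy'):
    """Return the 1-based epoch with the highest `metric` value (first one on ties) and that value."""
    values = history[metric]
    # max() returns the first index that attains the maximum, matching a strict-improvement checkpoint.
    best_index = max(range(len(values)), key=values.__getitem__)
    return best_index + 1, values[best_index]

# Toy history in the same layout as training_history.pkl.
toy_history = {'val_accuracy': [0.80, 0.84, 0.86, 0.85]}
print(best_epoch(toy_history))  # (3, 0.86)
```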

## 7. Evaluate the Best Model on the Test Set

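Now we load the **best model** saved by the `ModelCheckpoint` callback, so we evaluate the version that performed best on the validation data rather than the one from the final epoch. Note that the validation data here is the test split itself, so the test-set figures below are somewhat optimistic: the same images were used to pick the checkpoint.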

In [None]:
# Load the best model saved during training
print("Loading the best model from 'best_model.h5'...")
best_model = tf.keras.models.load_model('best_model.h5')

# Make predictions on the test set
y_pred_probs = best_model.predict(ds_test_images)
y_pred = np.argmax(y_pred_probs, axis=1)

# Generate and print the classification report
print("\n--- Classification Report ---")
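# labels= keeps the report aligned with target_names even if a class is missing from the test subset
report = classification_report(
    ds_test_labels, y_pred,
    labels=np.arange(num_classes), target_names=labels, zero_division=0
)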
print(report)

# Save the report to a text file
report_filename = 'classification_report.txt'
with open(report_filename, 'w') as f:
    f.write(report)
print(f"Saved classification report to '{report_filename}'")

### Confusion Matrix

The confusion matrix gives us a detailed, visual breakdown of which classes the model is confusing with others.

In [None]:
# Generate the confusion matrix (fixed to all 62 classes so it lines up with the tick labels)
conf_matrix = confusion_matrix(ds_test_labels, y_pred, labels=np.arange(num_classes))

plt.figure(figsize=(20, 20))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('True')

# Save the confusion matrix plot to a file
cm_filename = 'confusion_matrix.png'
plt.savefig(cm_filename)
print(f"Saved confusion matrix to '{cm_filename}'")
plt.show()
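With 62 classes, individual cells of the heatmap are hard to read. A complementary text summary is to list the largest off-diagonal entries, i.e. the most frequent (true, predicted) mistakes. The sketch below uses a hypothetical helper `most_confused_pairs` on a toy matrix with made-up counts; in the notebook it would be called as `most_confused_pairs(conf_matrix, labels)`.

```python
import numpy as np

def most_confused_pairs(cm, class_names, top_n=5):
    """List the largest off-diagonal cells of `cm` as (true, predicted, count)."""
    off_diag = cm.copy()
    np.fill_diagonal(off_diag, 0)  # ignore correct predictions
    # Flat indices of the largest counts, in descending order.
    flat_order = np.argsort(off_diag, axis=None)[::-1][:top_n]
    rows, cols = np.unravel_index(flat_order, off_diag.shape)
    return [(class_names[r], class_names[c], int(off_diag[r, c])) for r, c in zip(rows, cols)]

# Toy 3-class matrix: rows are true labels, columns are predictions.
toy_cm = np.array([[50, 7, 1],
                   [2, 40, 0],
                   [0, 3, 45]])
pairs = most_confused_pairs(toy_cm, ['0', 'O', 'o'], top_n=2)
print(pairs)  # [('0', 'O', 7), ('o', 'O', 3)]
```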

## 8. Save the Final Model for the Web App

Finally, we save the best-performing model with the standard name `htr_model.h5` so it can be easily used by the Flask application.

In [None]:
print("\nSaving the final, best-performing model...")
best_model.save('htr_model.h5')
print("--- Model training and evaluation complete. Best model saved as htr_model.h5 ---")
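Whatever code loads `htr_model.h5` has to feed it images prepared the same way as in training: 28x28, one channel, float values in [0, 1], and upright (training images were transposed upright in `preprocess()`, so user input should not be transposed again). The sketch below assumes the app already has a 28x28 grayscale `uint8` array with dark ink on a light background, and inverts it because EMNIST characters are bright strokes on a dark background. The function name `prepare_for_model` and the inversion step are assumptions about the Flask app, which is not shown here; resizing to 28x28 is left out.

```python
import numpy as np

def prepare_for_model(gray_28x28, dark_ink_on_light=True):
    """Hypothetical helper: turn a 28x28 grayscale array into a (1, 28, 28, 1) float32 batch."""
    img = np.asarray(gray_28x28, dtype=np.float32)
    if img.shape != (28, 28):
        raise ValueError(f'expected a 28x28 image, got {img.shape}')
    img = img / 255.0                 # same [0, 1] scaling as preprocess(); does not modify the input
    if dark_ink_on_light:
        img = 1.0 - img               # EMNIST strokes are bright on a dark background
    return img.reshape(1, 28, 28, 1)  # add batch and channel axes

# Example: a blank white canvas with one black pixel.
canvas = np.full((28, 28), 255, dtype=np.uint8)
canvas[14, 14] = 0
batch = prepare_for_model(canvas)
print(batch.shape)  # (1, 28, 28, 1)
```

The prediction would then be read as `labels[int(np.argmax(best_model.predict(batch)))]`.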