# Advanced Handwritten Character Recognition with EMNIST

This notebook provides a robust workflow for training a CNN model on the EMNIST dataset. Key features include:

1.  **Fast Training by Default**: Uses a 20% subset of the training data and runs for only 10 epochs.
2.  **Interrupt Handling**: Uses a custom callback to save training history (`loss`, `accuracy`) after every epoch. You can stop training at any time and still plot the results.
3.  **Improved CNN Model**: A more robust architecture with Dropout for better generalization.
4.  **Full Analysis**: Generates all necessary plots, a confusion matrix, and a classification report.

### 1. Imports and Custom Callback Definition

In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
import tensorflow_datasets as tfds
from sklearn.metrics import confusion_matrix, classification_report

# Custom callback to save history and allow plotting even if training is interrupted
class HistoryCallback(tf.keras.callbacks.Callback):
    """Collect per-epoch metrics and persist them to disk after every epoch.

    The metrics are kept in ``self.history`` (metric name -> list of values,
    one entry per finished epoch) and the epoch indices in ``self.epoch``.
    After each epoch the ``history`` dict is pickled to ``history_path``.

    Args:
        history_path: Destination of the pickled history. Defaults to
            ``'training_history.pkl'``.
    """

    def __init__(self, history_path='training_history.pkl'):
        super().__init__()
        self.history_path = history_path
        # Created here as well so both attributes exist (empty) even if
        # training is stopped before `on_train_begin` has run.
        self.epoch = []
        self.history = {}

    def on_train_begin(self, logs=None):
        """Reset the recorded epochs and metrics at the start of a run."""
        self.epoch = []
        self.history = {}

    def on_epoch_end(self, epoch, logs=None):
        """Append this epoch's metrics and write the history to disk."""
        logs = logs or {}
        self.epoch.append(epoch)
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)

        # Save history to a file after each epoch so we don't lose it
        self._save_history()

    def _save_history(self):
        """Pickle ``self.history`` to ``self.history_path`` without truncating it."""
        # Write to a sibling temp file first and swap it in afterwards, so an
        # interrupt in the middle of the dump cannot leave a half-written
        # history file behind; the previous epoch's file stays intact.
        tmp_path = f'{self.history_path}.tmp'
        with open(tmp_path, 'wb') as f:
            pickle.dump(self.history, f)
        os.replace(tmp_path, self.history_path)

    def on_train_end(self, logs=None):
        """Tell the user where the recorded history can be found."""
        print(f"Training finished. History is available in `history_callback.history` and `{self.history_path}`")

### 2. Load and Preprocess the EMNIST Dataset

We will use the `byclass` configuration of EMNIST, which has 62 classes (0-9, A-Z, a-z). By default, we load only **20% of the training data** and **50% of the test data** to speed things up. You can change this by modifying the `split` argument.

In [None]:
# Load the first 20% of the train split and the first 50% of the test split
# as (image, label) pairs, together with the dataset metadata.
# To use the full dataset, change the split to: split=['train', 'test']
(ds_train, ds_test), ds_info = tfds.load(
    name='emnist/byclass',
    split=['train[:20%]', 'test[:50%]'],
    as_supervised=True,
    with_info=True,
    shuffle_files=True,
)

def normalize_img(image, label):
    """Scale an image to float32 values in [0, 1] and swap its first two axes.

    The label is passed through untouched.
    """
    scaled = tf.cast(image, tf.float32) / 255.0
    # EMNIST images are stored rotated and flipped; exchanging the first two
    # axes (the channel axis stays last) restores the upright orientation.
    upright = tf.transpose(scaled, perm=[1, 0, 2])
    return upright, label

# Build the input pipelines
BATCH_SIZE = 128

# Training: normalize, cache the normalized examples, shuffle, then batch.
# NOTE(review): the shuffle buffer is sized from the example count ds_info
# reports for the 'train' split — confirm this is intended when only a slice
# of that split is loaded.
ds_train = (
    ds_train
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits['train'].num_examples)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# Test: normalize and batch first, then cache the finished batches (no shuffling).
ds_test = (
    ds_test
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

print(f"Number of classes: {ds_info.features['label'].num_classes}")

### 3. Build the CNN Model

In [None]:
def build_model(num_classes, input_shape=(28, 28, 1)):
    """Build and compile the CNN classifier.

    The network is two Conv2D + MaxPooling2D stages, followed by dropout, a
    256-unit dense layer, a second dropout and a softmax output layer.

    Args:
        num_classes: Number of units in the softmax output layer.
        input_shape: Shape of a single input image, without the batch
            dimension. Defaults to ``(28, 28, 1)``.

    Returns:
        A ``Sequential`` model compiled with the Adam optimizer, sparse
        categorical cross-entropy loss (labels are integer class ids) and an
        accuracy metric.
    """
    model = tf.keras.models.Sequential([
        # An explicit Input object replaces the `input_shape=` argument on the
        # first layer, which current Keras versions warn about.
        tf.keras.Input(shape=input_shape),
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.Dropout(0.25),  # Dropout for regularization
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(256, activation='relu'),
        tf.keras.layers.Dropout(0.5),  # Dropout for regularization
        tf.keras.layers.Dense(num_classes, activation='softmax'),
    ])
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model

# Create the network with one output unit per label reported by the dataset
# metadata, then print its layer summary.
model = build_model(num_classes=ds_info.features['label'].num_classes)
model.summary()

### 4. Train the Model

We will train for **10 epochs**. You can stop the training at any time: the metrics of every completed epoch are kept, so the plotting in the next cell will still work.

In [None]:
# Fresh callback instance that will collect this run's per-epoch metrics
history_callback = HistoryCallback()

# The value returned by fit() is not kept; the following cells read the
# metrics from `history_callback` instead.
model.fit(
    x=ds_train,
    validation_data=ds_test,
    epochs=10,  # Reduced for faster results
    callbacks=[history_callback],
)

### 5. Plot Training History

This cell will work even if you stopped the training manually.

In [None]:
# Use the history from our custom callback
history_data = history_callback.history

# If you restarted the notebook, you can load the history from the file
# with open('training_history.pkl', 'rb') as f:
#     history_data = pickle.load(f)

# The history is empty when training was stopped before the first epoch
# finished; say so instead of failing on a missing key.
if not history_data:
    print('No completed epochs recorded yet; the plots below will be empty.')

# One panel per metric: accuracy on the left, loss on the right
plt.figure(figsize=(12, 5))
for position, (metric, title) in enumerate([('accuracy', 'Accuracy'), ('loss', 'Loss')], start=1):
    plt.subplot(1, 2, position)
    # .get() keeps the cell runnable when a metric has not been recorded
    plt.plot(history_data.get(metric, []), label=f'Training {title}')
    plt.plot(history_data.get(f'val_{metric}', []), label=f'Validation {title}')
    plt.title(f'Training and Validation {title}')
    plt.xlabel('Epoch')
    plt.ylabel(title)
    plt.legend()

plt.tight_layout()
plt.show()

### 6. Evaluate Model and Save

Finally, we evaluate the model on the loaded test data to get the final accuracy, generate a confusion matrix and classification report, and save the trained model to `htr_model.h5`.

In [None]:
# Evaluate the model
loss, accuracy = model.evaluate(ds_test)
print(f'Final Test Accuracy: {accuracy*100:.2f}%')

# Get predictions: the most probable class per example, plus the true labels
# collected batch by batch from the same dataset
y_pred_probs = model.predict(ds_test)
y_pred = np.argmax(y_pred_probs, axis=1)
y_true = np.concatenate([y for _, y in ds_test], axis=0)

# Define class labels: digits, then uppercase, then lowercase letters
# (assumes label id i corresponds to class_names[i] — TODO confirm against
# the dataset's label definition)
class_names = list('0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz')
# Pass the full id range explicitly: otherwise a class that is missing from
# the sampled data makes classification_report reject `target_names` and
# shrinks the confusion matrix so it no longer lines up with the tick labels.
class_ids = np.arange(len(class_names))

# Classification Report
print('\nClassification Report:')
print(classification_report(
    y_true,
    y_pred,
    labels=class_ids,
    target_names=class_names,
    zero_division=0,  # report 0 (without a warning) for classes with no samples or predictions
))

# Confusion Matrix
cm = confusion_matrix(y_true, y_pred, labels=class_ids)
plt.figure(figsize=(20, 20))
sns.heatmap(cm, annot=False, xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()

# Save the final model
model.save('htr_model.h5')
print('\nModel saved as htr_model.h5')