# In [2]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

def _to_class_indices(values):
    """Reduce labels or model scores to a 1-D array of integer class indices.

    Accepted layouts:
        * 1-D array: assumed to already hold integer class indices.
        * 2-D array with one column: a binary score (e.g. sigmoid output) or
          0/1 label; values above 0.5 map to class 1, the rest to class 0.
        * 2-D array with several columns: one-hot labels or per-class
          probabilities; the column holding the largest value is chosen.
    """
    arr = np.asarray(values)
    if arr.ndim == 1:
        # NOTE(review): assumes 1-D input is integer-encoded labels, not a
        # flat vector of probabilities.
        return arr.astype(int)
    if arr.shape[1] == 1:
        # argmax over a single column is always 0, so threshold instead.
        return (arr[:, 0] > 0.5).astype(int)
    return arr.argmax(axis=1)


def plot_confusion_matrix(model, X_test, y_test, label_encoder, ordered_labels, title='Confusion Matrix'):
    """
    Plot a confusion matrix with red boxes around the diagonal.

    Parameters:
        model: Trained Keras model; only its ``predict`` method is used.
        X_test: Test feature data, passed unchanged to ``model.predict``.
        y_test: True labels, either one-hot encoded (2-D) or already
            integer-encoded (1-D).
        label_encoder: Fitted sklearn LabelEncoder used to map class indices
            back to the original label values.
        ordered_labels: Labels in the desired display order. Rows and
            columns of the matrix follow this order. The list is passed to
            sklearn's ``confusion_matrix(labels=...)``, so a label not listed
            here gets no row or column.
        title: Title for the plot.

    Returns:
        None. The figure is drawn and displayed with ``plt.show()``.
    """
    # Predict on the test set, then bring predictions and ground truth to the
    # same representation: integer class indices.
    y_pred_classes = _to_class_indices(model.predict(X_test))
    y_true_classes = _to_class_indices(y_test)

    # Decode class indices back to the original label values.
    y_test_labels = label_encoder.inverse_transform(y_true_classes)
    y_pred_labels = label_encoder.inverse_transform(y_pred_classes)

    # Rows are true labels, columns are predicted labels, both in ordered_labels order.
    cm = confusion_matrix(y_test_labels, y_pred_labels, labels=ordered_labels)
    n_labels = len(ordered_labels)
    positions = np.arange(n_labels)

    fig, ax = plt.subplots(figsize=(8, 6))
    cax = ax.matshow(cm, cmap='Blues')  # Display confusion matrix as a color map
    fig.colorbar(cax)

    # matshow puts the x tick labels on top; move them to the bottom so they
    # sit next to the 'Predicted Label' axis label.
    ax.xaxis.tick_bottom()
    ax.set_xticks(positions)
    ax.set_yticks(positions)
    # Right-anchor the rotated labels so each one ends at its own tick.
    ax.set_xticklabels(ordered_labels, rotation=45, ha='right', rotation_mode='anchor')
    ax.set_yticklabels(ordered_labels)
    ax.set_xlabel('Predicted Label')
    ax.set_ylabel('True Label')
    ax.set_title(title, pad=20)

    # Annotate each cell with its count. Use white text on dark cells so it
    # stays readable. The threshold is loop-invariant, so compute it once.
    threshold = cm.max() / 2
    for i in range(n_labels):
        for j in range(n_labels):
            ax.text(j, i, str(cm[i, j]),
                    va='center', ha='center',
                    color='white' if cm[i, j] > threshold else 'black')

    # Highlight correct predictions (the diagonal) with a red box.
    for k in range(n_labels):
        rect = plt.Rectangle((k - 0.5, k - 0.5), 1, 1, fill=False,
                             edgecolor='red', linewidth=2)
        ax.add_patch(rect)

    plt.tight_layout()
    plt.show()
