In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import json

In [None]:
# Restore the saved Keras model from disk.
# NOTE(review): this points at a SavedModel directory; Keras 3 (TF >= 2.16) no longer
# loads SavedModel directories through load_model — confirm the TensorFlow version in use.
SAVED_MODEL_DIR = 'model/saved_model'
model = tf.keras.models.load_model(SAVED_MODEL_DIR)


In [None]:
# Read the class-name -> index mapping stored next to the model.
with open('model/class_indices.json', 'r') as fh:
    class_indices = json.loads(fh.read())

In [None]:
# Reverse class_indices (class name -> index) so a predicted index can be looked up as its class name.
indices_to_classes = dict((index, name) for name, index in class_indices.items())


In [None]:
# Evaluate on test data: score the restored model on the held-out test set.
# NOTE(review): `test_generator` is not defined in any cell visible here, so a fresh
# "Restart & Run All" stops at this line — build it in an earlier cell (presumably
# with shuffle=False, since later cells compare predictions to test_generator.classes).
# NOTE(review): two-value unpacking assumes the model was compiled with exactly one metric.
evaluation = model.evaluate(test_generator)
test_loss, test_acc = evaluation
print(f'Test accuracy: {test_acc:.4f}')
print(f'Test loss: {test_loss:.4f}')

In [None]:
# Rewind the iterator so predictions start from the first test sample.
test_generator.reset()
# Raw class scores, one row per test sample.
y_pred = model.predict(test_generator)
# Arg-max column of each row is the predicted class index.
y_pred_classes = y_pred.argmax(axis=1)


In [None]:
# Get true labels: the ground-truth class index of every test sample.
# test_generator.classes is listed in the generator's own sample order, whereas
# y_pred_classes follows the order model.predict() drew batches in. The two only
# line up when the generator does not shuffle, so fail loudly here instead of
# reporting silently misaligned metrics below.
assert not getattr(test_generator, 'shuffle', False), (
    'test_generator must be created with shuffle=False so predictions align with test_generator.classes'
)
y_true = test_generator.classes
assert len(y_true) == len(y_pred_classes), (
    f'{len(y_true)} true labels vs {len(y_pred_classes)} predictions'
)

In [None]:
# Print per-class precision / recall / F1 on the test set.
# Class names are ordered by their integer index (not by JSON key order) so each
# name lines up with the label ids in y_true / y_pred_classes. Passing `labels`
# explicitly keeps the report valid even if a class never appears in either array.
report_label_ids = sorted(indices_to_classes)
report_class_names = [indices_to_classes[i] for i in report_label_ids]
print("\nClassification Report:")
print(classification_report(
    y_true,
    y_pred_classes,
    labels=report_label_ids,
    target_names=report_class_names,
))

In [None]:
# Create confusion matrix: rows are true classes, columns are predicted classes.
# Building it over every known label id (sorted by index) keeps the matrix
# n_classes x n_classes and makes its rows/columns line up with the tick labels,
# whatever order the keys were stored in the JSON file.
cm_label_ids = sorted(indices_to_classes)
cm_class_names = [indices_to_classes[i] for i in cm_label_ids]
cm = confusion_matrix(y_true, y_pred_classes, labels=cm_label_ids)

fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=cm_class_names,
    yticklabels=cm_class_names,
    ax=ax,
)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title('Confusion Matrix')
fig.tight_layout()
fig.savefig('model/confusion_matrix.png')
plt.show()

In [None]:
# Visualize some predictions
def plot_predictions(model, test_generator, num_images=6, index_to_class=None,
                     save_path='model/sample_predictions.png', n_cols=3):
    """Show test images in a grid, each titled with its true and predicted class.

    Only the first batch of ``test_generator`` is used. The generator is
    rewound with ``reset()``, then one batch is drawn with ``next()``, which
    leaves it advanced by that one batch.

    Args:
        model: Model whose ``predict`` returns one row of class scores per
            image; the arg-max column is taken as the predicted class.
        test_generator: Iterator supporting ``reset()`` and ``next()`` that
            yields ``(images, labels)`` batches (presumably a Keras
            ``DirectoryIterator`` — confirm).
        num_images: Maximum number of images to plot; capped at the batch size.
        index_to_class: Mapping from class index to class name. Defaults to the
            notebook-level ``indices_to_classes``.
        save_path: File the figure is saved to; ``None`` skips saving.
        n_cols: Number of grid columns. Rows are added as needed, so any
            ``num_images`` fits.

    Raises:
        ValueError: If ``n_cols`` is less than 1.
    """
    if n_cols < 1:
        raise ValueError(f"n_cols must be >= 1, got {n_cols}")
    if index_to_class is None:
        index_to_class = indices_to_classes

    test_generator.reset()
    images, labels = next(test_generator)

    n_shown = min(num_images, len(images))
    if n_shown <= 0:
        return

    # Only the images that will actually be drawn need a forward pass.
    pred_classes = np.argmax(model.predict(images[:n_shown]), axis=1)
    labels = np.asarray(labels)
    # One-hot label rows need argmax; 1-D labels are already class indices.
    if labels.ndim > 1:
        true_classes = np.argmax(labels, axis=1)
    else:
        true_classes = labels.astype(int)

    n_rows = (n_shown + n_cols - 1) // n_cols  # ceiling division
    # 5x5 inches per cell keeps the default 2x3 layout at a 15x10 figure.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
    for i, ax in enumerate(axes.flat):
        ax.axis('off')  # also blanks the unused trailing cells of the grid
        if i >= n_shown:
            continue
        image = np.asarray(images[i])
        # imshow rejects (H, W, 1) arrays; drop the channel axis and draw grayscale.
        if image.ndim == 3 and image.shape[-1] == 1:
            image = image[..., 0]
        # NOTE(review): imshow expects floats in [0, 1] or uint8 in [0, 255]; this
        # assumes the generator already rescales pixel values — confirm.
        ax.imshow(image, cmap='gray' if image.ndim == 2 else None)
        ax.set_title(
            f"True: {index_to_class[int(true_classes[i])]}\n"
            f"Pred: {index_to_class[int(pred_classes[i])]}"
        )

    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path)
    plt.show()

# Show a handful of test images with their true vs. predicted labels.
plot_predictions(model=model, test_generator=test_generator, num_images=6)