# Model Evaluation for Chest X-Ray Classification

This notebook evaluates the trained model. It computes the accuracy and the weighted F1 score, and plots a confusion matrix to analyze the model's performance on the test dataset.

In [1]:
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import torch

In [2]:
def evaluate_model(model, test_dl):
    """
    Evaluate the model on the test dataset.

    Runs inference over ``test_dl`` with gradient tracking disabled, prints
    the accuracy and the weighted F1 score, and displays an annotated
    confusion-matrix heatmap.

    Args:
        model (nn.Module): Trained model. As a side effect it is moved to the
            evaluation device (CUDA when available, otherwise CPU) and put in
            eval mode.
        test_dl (DataLoader): DataLoader for the test dataset, iterated as
            ``(inputs, labels)`` batches.

    Returns:
        None. Metrics are printed and the plot is shown via ``plt.show()``.
    """
    # Fall back to the CPU so evaluation also works on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    true_labels = []
    predicted_labels = []

    with torch.no_grad():
        for inputs, labels in test_dl:
            # Only the inputs have to live on the model's device; the labels
            # are consumed on the CPU by scikit-learn.
            outputs = model(inputs.to(device))
            _, predicted = torch.max(outputs, 1)

            true_labels.extend(labels.cpu().numpy())
            predicted_labels.extend(predicted.cpu().numpy())

    # Calculate accuracy and F1 score
    accuracy = accuracy_score(true_labels, predicted_labels)
    f1 = f1_score(true_labels, predicted_labels, average='weighted')

    print(f"Accuracy on Test Set: {accuracy}")
    print(f"F1 Score on Test Set: {f1}")

    classes = ["COVID-19", "Lung-Opacity", "Normal", "Viral Pneumonia", "Tuberculosis"]
    # Pin the label set so the matrix is always len(classes) x len(classes),
    # even when a class never occurs in the targets or the predictions.
    # Otherwise the matrix shrinks and the tick labels / annotation indices
    # below no longer line up.
    # NOTE(review): assumes labels are the integer indices 0..4 in the order
    # of `classes` -- confirm against the dataset's class mapping.
    cm = confusion_matrix(
        true_labels, predicted_labels, labels=np.arange(len(classes))
    )

    # Create plot with manual axis configuration
    fig, ax = plt.subplots(figsize=(12, 10))

    # Plot heatmap WITHOUT Seaborn's automatic annotations
    sns.heatmap(cm, cmap='Blues', ax=ax, annot=False, cbar=False)

    # Manually set axis ticks and labels (cell centres sit at index + 0.5)
    tick_positions = np.arange(len(classes)) + 0.5
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(classes, rotation=45, ha='right')
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(classes, rotation=0)

    # Add annotations manually to ensure they appear; dark cells get white
    # text and light cells get black text so the counts stay readable.
    half_max = cm.max() / 2
    for row, col in np.ndindex(*cm.shape):
        ax.text(col + 0.5, row + 0.5, str(cm[row, col]),
                ha='center', va='center',
                color='black' if cm[row, col] < half_max else 'white')

    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title('Confusion Matrix')
    fig.tight_layout()
    plt.show()