In [5]:
import os
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc
import time
import sys
from tqdm import tqdm

# Import your ensemble model functions
from ensemble2 import get_device, transform, load_resnet, load_efficientnet, load_vit, ensemble_predict

In [6]:

def test_ensemble_model(test_data_path):
    """
    Test the ensemble model on local test data
    
    Args:
        test_data_path (str): Path to the test data directory
    """
    # Ensure the test data directory exists
    if not os.path.exists(test_data_path):
        print(f"Error: Test data directory '{test_data_path}' not found.")
        return
    
    # Set up paths - look for cancer and non_cancer directories or use them if they're specified directly
    base_path = os.path.dirname(test_data_path) if os.path.basename(test_data_path) in ['cancer', 'non_cancer'] else test_data_path
    
    cancer_dir = os.path.join(base_path, 'cancer')
    non_cancer_dir = os.path.join(base_path, 'non_cancer')
    
    # Check that the directories exist
    if not os.path.exists(cancer_dir):
        print(f"Error: Cancer directory '{cancer_dir}' not found.")
        return
    
    if not os.path.exists(non_cancer_dir):
        print(f"Error: Non-cancer directory '{non_cancer_dir}' not found.")
        return
    
    # Get list of image files
    cancer_files = [os.path.join(cancer_dir, f) for f in os.listdir(cancer_dir) 
                    if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    non_cancer_files = [os.path.join(non_cancer_dir, f) for f in os.listdir(non_cancer_dir) 
                        if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    
    print(f"Found {len(cancer_files)} cancer images and {len(non_cancer_files)} non-cancer images.")
    
    # Prepare data for testing
    all_files = cancer_files + non_cancer_files
    true_labels = ([0] * len(cancer_files)) + ([1] * len(non_cancer_files))  # 0 for cancer, 1 for non-cancer
    
    # Initialize predictions list
    predictions = []
    confidences = []
    
    # Time tracking
    start_time = time.time()
    
    # Test each image
    for i, img_path in enumerate(tqdm(all_files, desc="Testing images")):
        try:
            # Load and process image
            img = Image.open(img_path).convert("RGB")
            
            # Get prediction from ensemble model
            result = ensemble_predict(img, img_path)
            
            # Record prediction (convert to 0/1 format)
            pred_label = 1 if result["prediction"] == "Non-Cancerous" else 0
            predictions.append(pred_label)
            confidences.append(result["confidence"])
            
        except Exception as e:
            print(f"Error processing image {img_path}: {e}")
            # If error, assume non-cancerous (this is arbitrary and could be changed)
            predictions.append(1)
            confidences.append(0.0)
    
    # Calculate time taken
    total_time = time.time() - start_time
    avg_time_per_image = total_time / len(all_files)
    
    # Calculate metrics
    accuracy = accuracy_score(true_labels, predictions)
    precision = precision_score(true_labels, predictions)
    recall = recall_score(true_labels, predictions)
    f1 = f1_score(true_labels, predictions)
    
    # Convert lists to numpy arrays for confusion matrix
    true_labels_np = np.array(true_labels)
    predictions_np = np.array(predictions)
    
    # Calculate confusion matrix
    cm = confusion_matrix(true_labels_np, predictions_np)
    
    # Calculate class-wise accuracy
    cancer_correct = sum(1 for i in range(len(cancer_files)) if predictions[i] == true_labels[i])
    non_cancer_correct = sum(1 for i in range(len(cancer_files), len(all_files)) if predictions[i] == true_labels[i])
    
    # Print results
    print("\n===== ENSEMBLE MODEL EVALUATION RESULTS =====")
    print(f"Total images tested: {len(all_files)}")
    print(f"Cancer images: {len(cancer_files)}")
    print(f"Non-cancer images: {len(non_cancer_files)}")
    print("\nPerformance Metrics:")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"\nAverage prediction time per image: {avg_time_per_image:.4f} seconds")
    
    # Create a results directory if it doesn't exist
    results_dir = 'model_evaluation_results'
    os.makedirs(results_dir, exist_ok=True)
    
    # Display confusion matrix
    class_names = ['Cancer', 'Non-Cancer']
    plt.figure(figsize=(10, 8))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    disp.plot(cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    cm_path = os.path.join(results_dir, 'confusion_matrix.png')
    plt.savefig(cm_path)
    plt.close()
    
    # Plot performance metrics
    plt.figure(figsize=(10, 6))
    metrics = ['Accuracy', 'Precision', 'Recall', 'F1 Score']
    values = [accuracy, precision, recall, f1]
    colors = ['#4CAF50', '#2196F3', '#FF9800', '#9C27B0']
    
    plt.bar(metrics, values, color=colors)
    plt.ylim(0, 1.0)
    plt.title('Model Performance Metrics')
    plt.ylabel('Score')
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    
    # Add value labels on top of bars
    for i, v in enumerate(values):
        plt.text(i, v + 0.02, f'{v:.4f}', ha='center', fontweight='bold')
    
    metrics_path = os.path.join(results_dir, 'performance_metrics.png')
    plt.savefig(metrics_path)
    plt.close()
    
    # Plot class-wise accuracy
    plt.figure(figsize=(8, 6))
    class_acc = [cancer_correct / len(cancer_files), non_cancer_correct / len(non_cancer_files)]
    plt.bar(['Cancer', 'Non-Cancer'], class_acc, color=['#E53935', '#43A047'])
    plt.ylim(0, 1.0)
    plt.title('Class-wise Accuracy')
    plt.ylabel('Accuracy')
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    
    # Add value labels and sample counts
    for i, v in enumerate(class_acc):
        count = cancer_correct if i == 0 else non_cancer_correct
        total = len(cancer_files) if i == 0 else len(non_cancer_files)
        plt.text(i, v + 0.02, f'{v:.4f}\n({count}/{total})', ha='center', fontweight='bold')
    
    class_acc_path = os.path.join(results_dir, 'class_accuracy.png')
    plt.savefig(class_acc_path)
    plt.close()
    
    # Plot confidence distribution
    plt.figure(figsize=(12, 6))
    
    # Split confidences by true class and prediction
    cancer_confidences = [confidences[i] for i in range(len(cancer_files))]
    non_cancer_confidences = [confidences[i] for i in range(len(cancer_files), len(all_files))]
    
    # Separate correct and incorrect predictions
    cancer_correct_conf = [confidences[i] for i in range(len(cancer_files)) if predictions[i] == true_labels[i]]
    cancer_incorrect_conf = [confidences[i] for i in range(len(cancer_files)) if predictions[i] != true_labels[i]]
    non_cancer_correct_conf = [confidences[i] for i in range(len(cancer_files), len(all_files)) if predictions[i] == true_labels[i]]
    non_cancer_incorrect_conf = [confidences[i] for i in range(len(cancer_files), len(all_files)) if predictions[i] != true_labels[i]]
    
    plt.subplot(1, 2, 1)
    plt.hist(cancer_correct_conf, alpha=0.7, bins=10, label='Correct', color='green')
    plt.hist(cancer_incorrect_conf, alpha=0.7, bins=10, label='Incorrect', color='red')
    plt.title('Cancer Class Confidence Distribution')
    plt.xlabel('Confidence (%)')
    plt.ylabel('Number of Images')
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.hist(non_cancer_correct_conf, alpha=0.7, bins=10, label='Correct', color='green')
    plt.hist(non_cancer_incorrect_conf, alpha=0.7, bins=10, label='Incorrect', color='red')
    plt.title('Non-Cancer Class Confidence Distribution')
    plt.xlabel('Confidence (%)')
    plt.ylabel('Number of Images')
    plt.legend()
    
    plt.tight_layout()
    conf_dist_path = os.path.join(results_dir, 'confidence_distribution.png')
    plt.savefig(conf_dist_path)
    plt.close()
    
    # ROC curve and AUC
    # Convert predictions to probabilities for cancer class (1 - confidence if predicted non-cancer)
    cancer_probs = []
    for i, pred in enumerate(predictions):
        if pred == 0:  # If predicted cancer
            cancer_probs.append(confidences[i] / 100)
        else:  # If predicted non-cancer
            cancer_probs.append(1 - (confidences[i] / 100))
    
    # Compute ROC curve and AUC
    fpr, tpr, _ = roc_curve([1 - label for label in true_labels], cancer_probs)  # Invert labels since ROC uses positive class
    roc_auc = auc(fpr, tpr)
    
    plt.figure(figsize=(8, 8))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.4f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc='lower right')
    plt.grid(linestyle='--', alpha=0.7)
    
    roc_path = os.path.join(results_dir, 'roc_curve.png')
    plt.savefig(roc_path)
    plt.close()
    
    print(f"\nAll plots saved in directory: {results_dir}")
    
    # Save detailed results
    results_file = os.path.join(results_dir, 'model_evaluation_results.txt')
    with open(results_file, 'w') as f:
        f.write("===== ENSEMBLE MODEL EVALUATION RESULTS =====\n")
        f.write(f"Total images tested: {len(all_files)}\n")
        f.write(f"Cancer images: {len(cancer_files)}\n")
        f.write(f"Non-cancer images: {len(non_cancer_files)}\n")
        f.write("\nPerformance Metrics:\n")
        f.write(f"Accuracy: {accuracy:.4f}\n")
        f.write(f"Precision: {precision:.4f}\n")
        f.write(f"Recall: {recall:.4f}\n")
        f.write(f"F1 Score: {f1:.4f}\n")
        f.write(f"\nAverage prediction time per image: {avg_time_per_image:.4f} seconds\n")
        f.write("\nDetailed Results:\n")
        f.write("Image Path | True Label | Predicted Label | Confidence\n")
        for i, img_path in enumerate(all_files):
            f.write(f"{img_path} | {'Cancer' if true_labels[i] == 0 else 'Non-Cancer'} | ")
            f.write(f"{'Cancer' if predictions[i] == 0 else 'Non-Cancer'} | {confidences[i]:.2f}%\n")
    
    print(f"\nDetailed results saved to '{results_file}'")
    
    # Calculate and display additional analysis
    print("\nClass-wise Accuracy:")
    print(f"Cancer class accuracy: {cancer_correct / len(cancer_files):.4f} ({cancer_correct}/{len(cancer_files)})")
    print(f"Non-cancer class accuracy: {non_cancer_correct / len(non_cancer_files):.4f} ({non_cancer_correct}/{len(non_cancer_files)})")
    
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'confusion_matrix': cm,
        'cancer_accuracy': cancer_correct / len(cancer_files),
        'non_cancer_accuracy': non_cancer_correct / len(non_cancer_files)
    }

if __name__ == "__main__":
    # Default test data directory. Set the ORAL_CANCER_TEST_DIR environment
    # variable to point at another location without editing this cell.
    test_dir = os.environ.get("ORAL_CANCER_TEST_DIR", r"D:\oral_cancer_detection\testdata")

    # Don't use sys.argv to avoid Jupyter notebook argument issues
    print(f"Testing ensemble model on data in: {test_dir}")
    test_ensemble_model(test_dir)

Testing ensemble model on data in: D:\oral_cancer_detection\testdata
Found 89 cancer images and 27 non-cancer images.


Testing images: 100%|██████████| 116/116 [00:48<00:00,  2.39it/s]



===== ENSEMBLE MODEL EVALUATION RESULTS =====
Total images tested: 116
Cancer images: 89
Non-cancer images: 27

Performance Metrics:
Accuracy: 0.9310
Precision: 0.7879
Recall: 0.9630
F1 Score: 0.8667

Average prediction time per image: 0.4187 seconds

All plots saved in directory: model_evaluation_results

Detailed results saved to 'model_evaluation_results\model_evaluation_results.txt'

Class-wise Accuracy:
Cancer class accuracy: 0.9213 (82/89)
Non-cancer class accuracy: 0.9630 (26/27)


<Figure size 1000x800 with 0 Axes>