In [None]:
from test_evaluator import run_evaluation
import torch
from pathlib import Path
from data_preprocessing import DataModule, DatasetConfig, DatasetType, ModelType
from CNN import create_model

def evaluate_model():
    """Evaluate the trained CNN checkpoint on the test split and print a summary.

    The function builds the test dataloader, instantiates the CNN, loads the
    saved weights from the checkpoint, hands everything to ``run_evaluation``
    and prints the overall metrics, per-class metrics and per-class thresholds
    found in the returned results.

    Returns:
        None. If the checkpoint cannot be loaded, the error is printed and the
        function returns early without running the evaluation.
    """
    # 1. Setup configurations
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Paths are assembled from components rather than written with literal
    # backslashes: a raw "a\b" string is a single file name on POSIX systems,
    # whereas joined components resolve correctly on every platform (and give
    # the same path as before on Windows).
    base_path = Path("data") / "constellation_dataset_1"
    save_dir = Path("evaluation_results")

    # 2. Setup data
    data_config = DatasetConfig(base_path)
    data_module = DataModule(
        data_config=data_config,
        model_type=ModelType.CNN,
        batch_size=32,
        num_workers=4,
    )

    # Get test dataloader
    test_loader = data_module.get_dataloader(DatasetType.TEST)

    # Class names are read from the TRAIN dataset's ``class_columns``.
    # NOTE(review): presumably identical across splits -- confirm.
    class_names = data_module.datasets[DatasetType.TRAIN].class_columns

    # 3. Load model
    model = create_model(
        model_type='cnn',
        num_classes=16,
        pretrained=False,  # Set to False since we're loading weights
        backbone='resnet50',
        dropout_rate=0.5,
    )

    # Load the trained weights
    checkpoint_path = Path(
        "src", "analytics", "image_analytics", "models", "CNN", "logs",
        "cnn_constellation_classification_20241120_145519", "best_model.pt",
    )
    try:
        # SECURITY NOTE: torch.load unpickles the file, so only load
        # checkpoints from a trusted source. ``weights_only=True`` is the
        # safer option but may reject this checkpoint depending on what else
        # it stores -- TODO confirm before enabling.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
    except Exception as e:
        # Best-effort script behaviour: report the problem and stop.
        print(f"Error loading model: {e}")
        return
    else:
        print(f"Successfully loaded model from {checkpoint_path}")

    model = model.to(device)
    # Switch to inference mode so dropout is disabled while scoring. This is
    # idempotent if run_evaluation already does the same.
    model.eval()

    # 4. Run evaluation
    results = run_evaluation(
        model=model,
        test_loader=test_loader,
        device=device,
        class_names=class_names,
        save_dir=save_dir,
    )

    # 5. Print detailed metrics summary
    separator = "-" * 50

    print("\nDetailed Evaluation Results:")
    print("\nOverall Metrics:")
    print(separator)
    for metric, value in results['overall_metrics'].items():
        print(f"{metric:25s}: {value:.4f}")

    print("\nPer-Class Metrics:")
    print(separator)
    for class_name, metrics in results['per_class_metrics'].items():
        print(f"\n{class_name}:")
        for metric, value in metrics.items():
            # The confusion matrix entry is skipped here; it is not a scalar
            # that fits the ".4f" format used for the other metrics.
            if metric == 'confusion_matrix':
                continue
            print(f"  {metric:25s}: {value:.4f}")

    print("\nOptimal Thresholds:")
    print(separator)
    for class_name, threshold in results['thresholds'].items():
        print(f"{class_name:25s}: {threshold:.4f}")

    print(f"\nDetailed results have been saved to {save_dir}")
    print("The following files have been generated:")
    print("1. evaluation_results.json - Complete results in JSON format")
    print("2. evaluation_report.md - Detailed report in Markdown format")
    print("3. evaluation_results.xlsx - Excel report with multiple sheets")
    print("4. roc_curves.png - ROC curves for all classes")
    print("5. precision_recall_curves.png - Precision-Recall curves")
    print("6. confusion_matrices.png - Confusion matrices for all classes")
    print("7. score_distributions.png - Score distributions for all classes")

# Run the evaluation only when this code is executed directly (as a script or
# in a notebook cell, where __name__ is "__main__"), never as a side effect of
# importing the module. The guard also keeps spawn-based worker processes,
# which re-import the main module, from re-running the evaluation.
if __name__ == "__main__":
    evaluate_model()

  return torch.FloatTensor(weights)
  checkpoint = torch.load(checkpoint_path, map_location=device)


Successfully loaded model from E:\University\CU_Classes\Year-1\Fall 2024\CSCI_5502_Data Mining\Milestone_project\stellar_mapping\src\analytics\image_analytics\models\CNN\logs\cnn_constellation_classification_20241120_145519\best_model.pt


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Evaluation completed!

Overall Metrics:
exact_match: 0.2650
hamming_loss: 0.0745
mean_average_precision: 0.8321
subset_accuracy: 0.2650
micro_precision: 0.9801
micro_recall: 0.6213
micro_f1: 0.7605
macro_precision: 0.9188
macro_recall: 0.7364
macro_f1: 0.7812

Detailed Evaluation Results:

Overall Metrics:
--------------------------------------------------
exact_match              : 0.2650
hamming_loss             : 0.0745
mean_average_precision   : 0.8321
subset_accuracy          : 0.2650
micro_precision          : 0.9801
micro_recall             : 0.6213
micro_f1                 : 0.7605
macro_precision          : 0.9188
macro_recall             : 0.7364
macro_f1                 : 0.7812

Per-Class Metrics:
--------------------------------------------------

 aquila:
  precision                : 0.9667
  recall                   : 0.9355
  f1_score                 : 0.9508
  roc_auc                  : 0.9995
  average_precision        : 0.9969
  support                  : 31.0000

 