```sh
python scripts/run_evaluation.py \
    --config configs/evaluation_config.yaml \
    --data-path /sps/lsst/groups/transients/HSC/fouchez/raphael/data/rc2_norm \
    --weights-path /sps/lsst/groups/transients/HSC/fouchez/raphael/training/simple_run \
    --output-dir saved/test_eval \
    --interpretability \
    --model-hash "6d5bb4aa"
```
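The command above is the script-based equivalent of this notebook. The argument definitions inside `scripts/run_evaluation.py` are not shown here, so the following is only a sketch of a parser that would accept this invocation. It assumes that `--interpretability` is a boolean switch, that `--weights-path` and `--model-hash` are optional (the notebook treats both as optional), and that every other flag takes a single string value.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser mirroring the flags used in the command above;
    # the real script may define defaults, types and help text differently.
    parser = argparse.ArgumentParser(description="Evaluate a trained ML4transients model")
    parser.add_argument("--config", required=True, help="Evaluation YAML config")
    parser.add_argument("--data-path", required=True, help="Directory containing cutouts/ and features/")
    parser.add_argument("--weights-path", help="Trained model directory (config.yaml, model_best.pth)")
    parser.add_argument("--output-dir", required=True, help="Where dashboards and summaries are written")
    parser.add_argument("--interpretability", action="store_true", help="Also run the UMAP analysis")
    parser.add_argument("--model-hash", help="Reuse inference results stored under this hash")
    return parser


args = build_parser().parse_args([
    "--config", "configs/evaluation_config.yaml",
    "--data-path", "/path/to/data",
    "--output-dir", "saved/test_eval",
    "--interpretability",
    "--model-hash", "6d5bb4aa",
])
# argparse turns dashes into underscores: args.model_hash, args.data_path, ...
print(args.model_hash, args.interpretability)  # 6d5bb4aa True
```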

# Model Evaluation Example Notebook

This notebook demonstrates how to use the ML4transients evaluation framework to comprehensively evaluate trained models. It covers:

1. **Basic Metrics Evaluation** - Computing standard classification metrics
2. **Interactive Dashboards** - Creating Bokeh-based visualization dashboards  
3. **Interpretability Analysis** - UMAP-based model interpretability with clustering
4. **Comparative Analysis** - Comparing multiple models

## Prerequisites

- Trained model weights directory with `config.yaml` and `model_best.pth`
- Dataset with cutouts and features
- Optional: Existing inference results

```python
# Import required libraries
import sys
import numpy as np
import pandas as pd
from pathlib import Path
import yaml
import torch
from torch.utils.data import DataLoader

# Add ML4transients to path
sys.path.append('/sps/lsst/users/rbonnetguerrini/ML4transients/src')

from ML4transients.data_access import DatasetLoader
from ML4transients.evaluation.metrics import EvaluationMetrics
from ML4transients.evaluation.visualizations import (
    create_evaluation_dashboard,
    create_interpretability_dashboard,
    BokehEvaluationPlots,
    UMAPVisualizer
)
from ML4transients.evaluation.interpretability import UMAPInterpreter
from ML4transients.training.pytorch_dataset import PytorchDataset

# Configure display settings
import warnings
warnings.filterwarnings('ignore')

print("✓ Libraries imported successfully")
```

```
✓ Libraries imported successfully
```


## 1. Configuration and Data Loading

First, let's set up paths and load our dataset. Modify these paths according to your setup.

```python
# Configuration - MODIFY THESE PATHS
DATA_PATH = "/sps/lsst/groups/transients/HSC/fouchez/raphael/data/rc2_norm"  # Path to directory containing cutouts/ and features/
WEIGHTS_PATH = "/sps/lsst/groups/transients/HSC/fouchez/raphael/training/simple_run"  # Path to trained model directory
OUTPUT_DIR = "../saved/test_eval/notebook"

# Optional: Model hash for existing inference results
MODEL_HASH = "6d5bb4aa"

# Create output directory
output_dir = Path(OUTPUT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)

print(f"Data path: {DATA_PATH}")
print(f"Model weights: {WEIGHTS_PATH}")
print(f"Output directory: {OUTPUT_DIR}")
```

```
Data path: /sps/lsst/groups/transients/HSC/fouchez/raphael/data/rc2_norm
Model weights: /sps/lsst/groups/transients/HSC/fouchez/raphael/training/simple_run
Output directory: ../saved/test_eval/notebook
```


```python
# Load dataset
print("Loading dataset...")
dataset_loader = DatasetLoader(DATA_PATH)

# Display dataset information
print(f"\n{dataset_loader}")
print(f"Available visits: {dataset_loader.visits[:10]}..." if len(dataset_loader.visits) > 10 else f"Available visits: {dataset_loader.visits}")

# Check for existing inference results
if MODEL_HASH:
    print("\nChecking for existing inference results...")
    dataset_loader.sync_inference_registry()
    dataset_loader.list_available_inference()
```

```
Loading dataset...
Loaded inference registry from /sps/lsst/groups/transients/HSC/fouchez/raphael/data/rc2_norm/inference/inference_registry.json

DatasetLoader(40 visits, 1 paths)
  Cutouts: 57905 across 40 visits
  Features: 57905 across 40 visits
Available visits: [322, 346, 358, 1178, 1184, 1204, 1206, 1214, 1220, 1242]...

Checking for existing inference results...
Syncing inference registry for /sps/lsst/groups/transients/HSC/fouchez/raphael/data/rc2_norm...
Saved inference registry to /sps/lsst/groups/transients/HSC/fouchez/raphael/data/rc2_norm/inference/inference_registry.json
Registry sync complete: added 0, removed 0 entries
Available inference files in /sps/lsst/groups/transients/HSC/fouchez/raphael/data/rc2_norm:
  Visit 1214:
    Model hash: 6d5bb4aa
      File: visit_1214_inference_6d5bb4aa.h5
  Visit 29336:
    Model hash: 6d5bb4aa
      File: visit_29336_inference_6d5bb4aa.h5
  Visit 22662:
    Model hash: 6d5bb4aa
      File: visit_22662_inference_6d5bb4aa.h5
  ...
```
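The registry listing suggests that inference results are stored as one HDF5 file per visit and model, named `visit_<visit>_inference_<hash>.h5`. Assuming that naming convention holds (it is inferred from the listing above, not from the library's code), a small stdlib helper can index a set of such file names by model hash, for example the `.h5` names found under the dataset's `inference/` directory, if that is where they are stored.

```python
import re
from collections import defaultdict
from pathlib import Path

# Pattern inferred from file names such as "visit_1214_inference_6d5bb4aa.h5".
INFERENCE_FILE_RE = re.compile(r"^visit_(\d+)_inference_([0-9a-f]+)\.h5$")


def index_inference_files(names):
    """Map model hash -> sorted list of visits that have an inference file."""
    index = defaultdict(list)
    for name in names:
        match = INFERENCE_FILE_RE.match(Path(name).name)
        if match is None:
            continue  # skip the registry JSON and any unrelated files
        visit, model_hash = int(match.group(1)), match.group(2)
        index[model_hash].append(visit)
    return {h: sorted(visits) for h, visits in index.items()}


files = [
    "visit_1214_inference_6d5bb4aa.h5",
    "visit_29336_inference_6d5bb4aa.h5",
    "visit_22662_inference_6d5bb4aa.h5",
    "inference_registry.json",
]
print(index_inference_files(files))  # {'6d5bb4aa': [1214, 22662, 29336]}
```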

## 2. Generate or Load Inference Results

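Results can either be generated by running inference or loaded from files saved by a previous run. Both options are shown below: option 1 is left commented out, and the notebook uses option 2 to load the existing results for `MODEL_HASH`.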

```python
# Option 1: check for existing inference results (running inference if needed)
# print("Checking for existing inference results...")
# inference_results = dataset_loader.check_or_run_inference(WEIGHTS_PATH)

# Option 2: load existing results by model hash
print("Loading existing inference results by model hash...")
inference_results = {}

for visit in dataset_loader.visits:
    inference_loader = dataset_loader.get_inference_loader(
        visit=visit,
        model_hash=MODEL_HASH
    )
    if inference_loader and inference_loader.has_inference_results():
        inference_results[visit] = inference_loader

print(f"Loaded inference results for {len(inference_results)} visits")

# Verify we have results
if not inference_results:
    raise ValueError("No inference results available. Please run inference first.")

print(f"\nSuccessfully loaded inference results for {len(inference_results)} visits")
```

```
Loading existing inference results by model hash...
Loaded inference loader for visit 322, model 6d5bb4aa
Loaded inference loader for visit 346, model 6d5bb4aa
Loaded inference loader for visit 358, model 6d5bb4aa
Loaded inference loader for visit 1178, model 6d5bb4aa
Loaded inference loader for visit 1184, model 6d5bb4aa
Loaded inference loader for visit 1204, model 6d5bb4aa
Loaded inference loader for visit 1206, model 6d5bb4aa
Loaded inference loader for visit 1214, model 6d5bb4aa
Loaded inference loader for visit 1220, model 6d5bb4aa
Loaded inference loader for visit 1242, model 6d5bb4aa
Loaded inference loader for visit 1248, model 6d5bb4aa
Loaded inference loader for visit 11690, model 6d5bb4aa
Loaded inference loader for visit 11694, model 6d5bb4aa
Loaded inference loader for visit 11696, model 6d5bb4aa
Loaded inference loader for visit 11698, model 6d5bb4aa
Loaded inference loader for visit 11704, model 6d5bb4aa
Loaded inference loader for visit 11710, model 6d5bb4aa
...
```
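The loading loop silently skips visits without stored results. The aggregated sample count (57,328) is lower than the number of cutouts (57,905); the gap may come from skipped visits or from per-visit differences, and checking the former is cheap. A minimal sketch, where `get_loader` is a hypothetical stand-in for `dataset_loader.get_inference_loader(visit=..., model_hash=...)` that returns `None` or an object with `has_inference_results()`:

```python
from typing import Callable, Iterable, Optional, Protocol


class HasResults(Protocol):
    def has_inference_results(self) -> bool: ...


def split_visits(visits: Iterable[int],
                 get_loader: Callable[[int], Optional[HasResults]]):
    """Return (loaded, missing): visit -> loader for hits, list of visits without results."""
    loaded, missing = {}, []
    for visit in visits:
        loader = get_loader(visit)
        if loader is not None and loader.has_inference_results():
            loaded[visit] = loader
        else:
            missing.append(visit)
    return loaded, missing


# Toy stand-in for an inference loader, for illustration only.
class FakeLoader:
    def __init__(self, ok: bool):
        self.ok = ok

    def has_inference_results(self) -> bool:
        return self.ok


fake_store = {322: FakeLoader(True), 346: FakeLoader(False)}
loaded, missing = split_visits([322, 346, 358], fake_store.get)
print(sorted(loaded), missing)  # [322] [346, 358]
```

In the notebook this would be called as `split_visits(dataset_loader.visits, lambda v: dataset_loader.get_inference_loader(visit=v, model_hash=MODEL_HASH))`.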

```python
# Aggregate results across all visits
print("Aggregating inference results...")

all_predictions = []
all_labels = []
all_source_ids = []

for visit, inference_loader in inference_results.items():
    all_predictions.append(inference_loader.predictions)
    all_labels.append(inference_loader.labels)
    all_source_ids.append(inference_loader.ids)

# Concatenate results
predictions = np.concatenate(all_predictions)
labels = np.concatenate(all_labels)
source_ids = np.concatenate(all_source_ids)

# Labels and predictions are stored as floats; cast the counts to int for display
print(f"Total samples: {len(predictions):,}")
print(f"Positive samples: {int(np.sum(labels)):,} ({np.mean(labels)*100:.1f}%)")
print(f"Predicted positive: {int(np.sum(predictions)):,} ({np.mean(predictions)*100:.1f}%)")
```

```
Aggregating inference results...
Total samples: 57,328
Positive samples: 12,991 (22.7%)
Predicted positive: 13,135 (22.9%)
```
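The aggregation loop also collects `source_ids`, which nothing else in the notebook uses. Selecting a subset of the concatenated arrays by position is only safe if every subset is taken in exactly the same order as the aggregation. A minimal sketch of an ID-keyed lookup instead, using plain lists as stand-ins for the NumPy arrays (the helper names are hypothetical, not part of ML4transients):

```python
def build_lookup(source_ids, predictions, labels):
    """Map each source ID to its (prediction, label) pair; IDs must be unique."""
    lookup = {}
    for sid, pred, lab in zip(source_ids, predictions, labels):
        if sid in lookup:
            raise ValueError(f"duplicate source id {sid}")
        lookup[sid] = (int(pred), int(lab))
    return lookup


def select(lookup, wanted_ids):
    """Return predictions and labels in the order of wanted_ids."""
    pairs = [lookup[sid] for sid in wanted_ids]  # KeyError if an ID has no result
    return [p for p, _ in pairs], [l for _, l in pairs]


lookup = build_lookup([101, 102, 103], [1.0, 0.0, 1.0], [1.0, 0.0, 0.0])
preds, labs = select(lookup, [103, 101])
print(preds, labs)  # [1, 1] [0, 1]
```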


## 3. Basic Metrics Evaluation

Let's compute and display the standard classification metrics.

```python
# Create metrics object
print("Computing evaluation metrics...")
metrics = EvaluationMetrics(predictions, labels)

# Display summary
summary = metrics.summary()
print("\n" + "="*60)
print("MODEL EVALUATION SUMMARY")
print("="*60)

for metric_name, value in summary.items():
    print(f"{metric_name.replace('_', ' ').title():.<30} {value:.4f}")

print("="*60)
```

```
Computing evaluation metrics...

============================================================
MODEL EVALUATION SUMMARY
============================================================
Accuracy...................... 0.9857
Precision..................... 0.9632
Recall........................ 0.9739
F1 Score...................... 0.9685
Specificity................... 0.9891
============================================================
```


```python
# Detailed confusion matrix statistics
print("\nConfusion Matrix Analysis:")
print("-" * 40)

cm_stats = metrics.get_confusion_matrix_stats()
for key, value in cm_stats.items():
    if isinstance(value, (int, np.integer)):
        print(f"{key.replace('_', ' ').title():.<30} {value:,}")
    else:
        print(f"{key.replace('_', ' ').title():.<30} {value:.4f}")
```

```

Confusion Matrix Analysis:
----------------------------------------
True Positive................. 12,652
True Negative................. 43,854
False Positive................ 483
False Negative................ 339
True Positive Rate............ 0.9739
True Negative Rate............ 0.9891
Positive Predictive Value..... 0.9632
Negative Predictive Value..... 0.9923
False Positive Rate........... 0.0109
False Negative Rate........... 0.0261
Total Samples................. 57,328
```
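All of the rates above follow from the four confusion-matrix counts. As a cross-check, the sketch below applies the standard definitions (`EvaluationMetrics` may compute them differently internally) to the counts reported by `get_confusion_matrix_stats()`:

```python
def rates(tp: int, tn: int, fp: int, fn: int) -> dict:
    """Standard binary-classification metrics from confusion-matrix counts."""
    precision = tp / (tp + fp)  # positive predictive value
    recall = tp / (tp + fn)     # true positive rate
    return {
        "accuracy": (tp + tn) / (tp + tn + fp + fn),
        "precision": precision,
        "recall": recall,
        "f1": 2 * precision * recall / (precision + recall),
        "specificity": tn / (tn + fp),  # true negative rate
        "npv": tn / (tn + fn),          # negative predictive value
        "fpr": fp / (fp + tn),
        "fnr": fn / (fn + tp),
    }


# Counts taken from the confusion-matrix output above.
result = rates(tp=12_652, tn=43_854, fp=483, fn=339)
for name, value in result.items():
    print(f"{name:<12}{value:.4f}")
```

Rounded to four decimals, these reproduce the summary table (accuracy 0.9857, precision 0.9632, recall 0.9739, F1 0.9685, specificity 0.9891) and the NPV, FPR and FNR lines (0.9923, 0.0109, 0.0261).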


## 4. Interactive Evaluation Dashboard

Create an interactive Bokeh dashboard with all evaluation plots.

```python
# Create evaluation dashboard
print("Creating evaluation dashboard...")

model_name = Path(WEIGHTS_PATH).name if WEIGHTS_PATH else f"Model_{MODEL_HASH}"
dashboard_path = output_dir / "evaluation_dashboard.html"

evaluation_dashboard = create_evaluation_dashboard(
    metrics,
    output_path=dashboard_path,
    title=f"Model Evaluation - {model_name}"
)

print(f"Evaluation dashboard saved to: {dashboard_path}")
```

```
Creating evaluation dashboard...
Dashboard saved to ../saved/test_eval/notebook/evaluation_dashboard.html
Evaluation dashboard saved to: ../saved/test_eval/notebook/evaluation_dashboard.html
```


## 5. Interpretability Analysis with UMAP

Now let's perform UMAP-based interpretability analysis to understand what the model has learned.

```python
# Configuration for interpretability analysis
INTERPRETABILITY_CONFIG = {
    'max_samples': 1000,  # Limit samples for visualization performance
    'layer_name': 'fc1',  # Which layer to extract features from
    'umap': {
        'n_neighbors': 15,
        'min_dist': 0.1,
        'n_components': 2,
        'random_state': 42
    },
    'clustering': {
        'n_components_range': [5, 15]
    }
}
```
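A typo in this dictionary only surfaces several cells later, so a light sanity check can run right after it is defined. The rules below are assumptions about what the downstream code expects (for instance, that `n_components_range` is a `[low, high]` pair and that the scatter plots need a 2-D embedding); they are not documented requirements of ML4transients, and `check_interpretability_config` is a hypothetical helper.

```python
def check_interpretability_config(cfg: dict) -> list:
    """Return human-readable problems; an empty list means the config looks sane."""
    problems = []
    if not isinstance(cfg.get("max_samples"), int) or cfg["max_samples"] <= 0:
        problems.append("max_samples should be a positive integer")
    if not cfg.get("layer_name"):
        problems.append("layer_name is required to pick the feature layer")
    umap_cfg = cfg.get("umap", {})
    if umap_cfg.get("n_components") != 2:
        problems.append("umap.n_components should be 2 for 2-D scatter plots")
    if not 0.0 <= umap_cfg.get("min_dist", 0.1) <= 1.0:
        problems.append("umap.min_dist is usually in [0, 1]")
    rng = cfg.get("clustering", {}).get("n_components_range", [])
    if len(rng) != 2 or rng[0] < 1 or rng[0] > rng[1]:
        problems.append("clustering.n_components_range should be [low, high] with 1 <= low <= high")
    return problems


config = {
    "max_samples": 1000,
    "layer_name": "fc1",
    "umap": {"n_neighbors": 15, "min_dist": 0.1, "n_components": 2, "random_state": 42},
    "clustering": {"n_components_range": [5, 15]},
}
print(check_interpretability_config(config))  # []
```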



```python
# Create dataset for interpretability analysis
print("Creating evaluation dataset for interpretability...")

# Select subset of visits for performance (optional)
selected_visits = dataset_loader.visits[:3]  # Use first 3 visits for demo
print(f"Using visits: {selected_visits}")

eval_dataset = PytorchDataset.create_inference_dataset(
    dataset_loader,
    visits=selected_visits
)

eval_dataloader = DataLoader(
    eval_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=0,
    pin_memory=False
)

print(f"Created evaluation dataset with {len(eval_dataset)} samples")
```

```
Creating evaluation dataset for interpretability...
Using visits: [322, 346, 358]
Building inference dataset index across visits...
Created inference dataset with 1805 samples across 3 visits
Created lazy dataset with 1805 samples
Created evaluation dataset with 1805 samples
```
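With `batch_size=64` and PyTorch's default `drop_last=False`, the 1,805 samples are split into ⌈1805 / 64⌉ = 29 batches: 28 full batches plus a final batch of 13 samples. This is where the `Processing batch …/29` messages printed during feature extraction come from.

```python
import math

n_samples, batch_size = 1805, 64
n_batches = math.ceil(n_samples / batch_size)       # the last partial batch is kept
last_batch = n_samples - (n_batches - 1) * batch_size
print(n_batches, last_batch)  # 29 13
```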


```python
# Initialize UMAP interpreter
if not WEIGHTS_PATH:
    print("Warning: UMAP interpretability requires WEIGHTS_PATH, skipping...")
else:
    print("Initializing UMAP interpreter...")
    interpreter = UMAPInterpreter(WEIGHTS_PATH)

    # Extract features from specified layer
    print(f"Extracting features from layer: {INTERPRETABILITY_CONFIG['layer_name']}")
    features = interpreter.extract_features(
        eval_dataloader,
        layer_name=INTERPRETABILITY_CONFIG['layer_name']
    )

    print(f"Extracted features shape: {features.shape}")
```

```
Initializing UMAP interpreter...
TensorBoard logging to: runs/gpu_training_845962
Run 'tensorboard --logdir=runs' to view logs
Extracting features from layer: fc1
Processing batch 1/29
Processing batch 11/29
Processing batch 21/29
Extracted features shape: (1805, 128)
```


```python
# Fit UMAP embedding
if WEIGHTS_PATH:
    print("Fitting UMAP embedding...")
    umap_embedding = interpreter.fit_umap(**INTERPRETABILITY_CONFIG['umap'])

    print(f"UMAP embedding shape: {umap_embedding.shape}")
    print(f"  X range: [{umap_embedding[:, 0].min():.2f}, {umap_embedding[:, 0].max():.2f}]")
    print(f"  Y range: [{umap_embedding[:, 1].min():.2f}, {umap_embedding[:, 1].max():.2f}]")
```

```
Fitting UMAP embedding...
UMAP embedding shape: (1805, 2)
  X range: [-4.78, 10.66]
  Y range: [-1.57, 11.39]
```


```python
# Create interpretability dataframe
if WEIGHTS_PATH:
    print("Creating interpretability dataframe...")

    # Take predictions/labels from the same visits as eval_dataset. Slicing the first
    # len(eval_dataset) entries of the aggregated arrays is only correct if these visits
    # happen to come first in inference_results. This still assumes each visit's results
    # are stored in the same order as the dataset index.
    eval_predictions = np.concatenate([inference_results[v].predictions for v in selected_visits])
    eval_labels = np.concatenate([inference_results[v].labels for v in selected_visits])
    assert len(eval_predictions) == len(eval_dataset), \
        "Inference results and evaluation dataset differ in size; align them by source ID"

    # Optional: Add additional features (e.g., SNR if available)
    additional_features = {}
    # Example: additional_features['SNR'] = your_snr_values

    # Create dataframe with UMAP coordinates and images
    interp_df = interpreter.create_interpretability_dataframe(
        eval_predictions,
        eval_labels,
        eval_dataloader,
        additional_features=additional_features
    )

    # Add clustering
    print("Performing clustering analysis...")
    interp_df = interpreter.add_clustering_to_dataframe(
        interp_df,
        tuple(INTERPRETABILITY_CONFIG['clustering']['n_components_range'])
    )

    print(f"✓ Created interpretability dataframe with {len(interp_df)} samples")
    print(f"  Columns: {list(interp_df.columns)}")
    print(f"  Clusters found: {sorted(interp_df['cluster'].unique())}")
```

```
Creating interpretability dataframe...
Creating embeddable images for hover tooltips...
Processed 640 images...
Processed 1280 images...
Created 1805 embeddable images
Performing clustering analysis...
✓ Created interpretability dataframe with 1805 samples
  Columns: ['umap_x', 'umap_y', 'prediction', 'true_label', 'correct', 'class_type', 'image', 'cluster']
  Clusters found: ['0', '1', '2', '3', '4', '5', '6']
```
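The dataframe carries `prediction`, `true_label`, `class_type` and `cluster` columns, and the save step further down refers to a per-cluster performance table (`perf_df`) that is not built in this excerpt. Below is a stdlib sketch of such a summary. It assumes `class_type` encodes TP/TN/FP/FN (the exact labels ML4transients uses are not shown here) and uses a list of dicts in place of the pandas dataframe; both helpers are hypothetical.

```python
from collections import Counter, defaultdict


def class_type(prediction: int, label: int) -> str:
    """Assumed encoding: TP/TN/FP/FN from a binary prediction and label."""
    if prediction == 1:
        return "TP" if label == 1 else "FP"
    return "FN" if label == 1 else "TN"


def cluster_performance(rows):
    """Per-cluster sample count, accuracy and confusion-type counts."""
    groups = defaultdict(list)
    for row in rows:
        groups[row["cluster"]].append(row)
    table = []
    for cluster in sorted(groups, key=int):  # cluster labels are numeric strings such as '0', '1'
        members = groups[cluster]
        types = Counter(class_type(r["prediction"], r["true_label"]) for r in members)
        n_correct = types["TP"] + types["TN"]  # missing keys count as 0
        table.append({"cluster": cluster, "n": len(members),
                      "accuracy": n_correct / len(members), **types})
    return table


rows = [
    {"cluster": "0", "prediction": 1, "true_label": 1},
    {"cluster": "0", "prediction": 1, "true_label": 0},
    {"cluster": "1", "prediction": 0, "true_label": 0},
]
for entry in cluster_performance(rows):
    print(entry)
# {'cluster': '0', 'n': 2, 'accuracy': 0.5, 'TP': 1, 'FP': 1}
# {'cluster': '1', 'n': 1, 'accuracy': 1.0, 'TN': 1}
```

With the notebook's dataframe, the input would be `interp_df.to_dict("records")`; `pandas.DataFrame(table)` then gives a `perf_df`-style table.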


## 6. Interactive Interpretability Dashboard

Create an interactive dashboard for exploring the UMAP visualization.

```python
# Create interpretability dashboard
if WEIGHTS_PATH:
    print("Creating interpretability dashboard...")

    interp_dashboard_path = output_dir / "interpretability_dashboard.html"

    interp_dashboard = create_interpretability_dashboard(
        interpreter,
        eval_dataloader,
        eval_predictions,
        eval_labels,
        output_path=interp_dashboard_path,
        additional_features=additional_features,
        config={'interpretability': INTERPRETABILITY_CONFIG}
    )

    print(f"Interpretability dashboard saved to: {interp_dashboard_path}")
```

```
Creating interpretability dashboard...
Extracting features...
Extracting features from layer: fc1
Processing batch 1/29
Processing batch 11/29
Processing batch 21/29
Extracted features shape: (1805, 128)
Fitting UMAP...
Creating visualization dataframe (max 1000 samples)...
Creating embeddable images for hover tooltips...
Processed 640 images...
Processed 1280 images...
Created 1805 embeddable images
Performing clustering...
Created visualization dataframe with 1805 samples using fc1 features
Interpretability dashboard saved to ../saved/test_eval/notebook/interpretability_dashboard.html
Interpretability dashboard saved to: ../saved/test_eval/notebook/interpretability_dashboard.html
```

## 8. Save Analysis Results

Save all results for future reference and sharing.

```python
# Save analysis results
results_summary = {
    'model_info': {
        'weights_path': str(WEIGHTS_PATH) if WEIGHTS_PATH else None,
        'model_hash': MODEL_HASH,
        'data_path': str(DATA_PATH),
        'visits_evaluated': [int(v) for v in inference_results.keys()],
        'total_samples': int(len(predictions))
    },
    'metrics': {k: float(v) for k, v in summary.items()},
    'confusion_matrix_stats': {k: int(v) if isinstance(v, (int, np.integer)) else float(v)
                               for k, v in cm_stats.items()},
}

if WEIGHTS_PATH:
    # n_clusters is not defined elsewhere in this notebook; derive it from the clustering output
    n_clusters = interp_df['cluster'].nunique()
    results_summary['interpretability'] = {
        'config': INTERPRETABILITY_CONFIG,
        'umap_embedding_shape': list(umap_embedding.shape),
        'n_clusters': int(n_clusters),
        # perf_df only exists if a per-cluster performance table was built earlier
        'cluster_performance': perf_df.to_dict('records') if 'perf_df' in globals() else None
    }

# Save to YAML file
results_file = output_dir / "evaluation_summary.yaml"
with open(results_file, 'w') as f:
    yaml.dump(results_summary, f, default_flow_style=False)

print(f"✓ Analysis results saved to: {results_file}")

# Save interpretability dataframe (the image column holds embedded images, so drop it for CSV)
if WEIGHTS_PATH:
    interp_csv = output_dir / "interpretability_data.csv"
    interp_df.drop(columns=['image']).to_csv(interp_csv, index=False)
    print(f"✓ Interpretability data saved to: {interp_csv}")
```

```
✓ Analysis results saved to: ../saved/test_eval/notebook/evaluation_summary.yaml
✓ Interpretability data saved to: ../saved/test_eval/notebook/interpretability_data.csv
```
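PyYAML's default `yaml.dump` writes NumPy scalars and arrays as `!!python/object` tags rather than plain numbers, and `yaml.safe_dump` refuses them outright, which is presumably why the save cell casts values with `int(...)` and `float(...)`. A recursive converter avoids missing a spot. The sketch below is duck-typed so it needs no NumPy import (NumPy arrays and scalars both expose `.tolist()`); `to_builtin` and the fake classes are hypothetical, for illustration only.

```python
def to_builtin(obj):
    """Recursively convert containers and NumPy-like values to plain Python types.

    Dict keys are converted with str(), which suits YAML summaries keyed by name.
    """
    if isinstance(obj, dict):
        return {str(k): to_builtin(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_builtin(v) for v in obj]
    if hasattr(obj, "tolist"):  # NumPy arrays and NumPy scalars
        return to_builtin(obj.tolist())
    if hasattr(obj, "item"):    # other scalar wrappers
        return obj.item()
    return obj


# Stand-ins for NumPy objects, so the example runs without NumPy.
class FakeScalar:
    def __init__(self, v):
        self.v = v

    def item(self):
        return self.v


class FakeArray:
    def __init__(self, vals):
        self.vals = vals

    def tolist(self):
        return list(self.vals)


summary_like = {"n_clusters": FakeScalar(7), "shape": (1805, 2), "visits": FakeArray([322, 346])}
print(to_builtin(summary_like))  # {'n_clusters': 7, 'shape': [1805, 2], 'visits': [322, 346]}
```

In the notebook this would become `yaml.safe_dump(to_builtin(results_summary), f, default_flow_style=False)`; `safe_dump` then fails loudly if anything non-builtin slips through.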
