# LSMI Evaluation Notebook

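This notebook evaluates illuminant CNN models (IlluminantCNN by default; IllumiCam3, ConfidenceWeightedCNN, and ColorConstancyCNN are also supported) on the balanced LSMI test set.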
It performs the following:
1.  Loads the pre-trained model.
2.  Runs inference on the test images.
3.  Generates class-activation heatmaps (Grad-CAM, Grad-CAM++, or Score-CAM) for each illuminant cluster present in the scene.
4.  Calculates metrics (IOU, DICE, P-MAE) against the Ground Truth masks.
5.  Visualizes the results.

In [1]:
import os
import numpy as np
import cv2
import torch
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import pandas as pd

# Import Models
from illuminant_estimation.models.cnn import IlluminantCNN, IllumiCam3, ConfidenceWeightedCNN, ColorConstancyCNN

# Import CAM Methods
from pytorch_grad_cam import GradCAM, GradCAMPlusPlus, ScoreCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

# Import Test Package Utils
import sys
sys.path.append("LSMI_Test_Package")
from lsmi_utils import process_raw_image, load_mask, get_cluster_names

# ---------------------------------------------------------------------------
# Constants: test-package locations, checkpoint, device and cluster names
# ---------------------------------------------------------------------------
TEST_PACKAGE_DIR = "LSMI_Test_Package"  # root of the bundled test package
IMAGES_DIR, MASKS_DIR = (
    os.path.join(TEST_PACKAGE_DIR, "images"),
    os.path.join(TEST_PACKAGE_DIR, "masks"),
)

# Checkpoint for the single-model evaluation; edit to point at your own weights.
MODEL_PATH = "best_illuminant_cnn_val_8084.pth"

# Run on the GPU when PyTorch can see one.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Cluster names, in the order provided by the test package helper.
CLUSTER_NAMES = get_cluster_names()

print(f"Using device: {DEVICE}")




Using device: cuda


## 1. Load Model
Load the pre-trained model selected by `MODEL_NAME` (IlluminantCNN by default) from `MODEL_PATH`.

In [None]:
def _unwrap_state_dict(checkpoint):
    """Return the plain parameter dict stored in ``checkpoint``.

    Accepts a bare ``state_dict``, a dict wrapping one under ``"state_dict"`` or
    ``"model_state_dict"``, or a pickled ``nn.Module``. If every key carries the
    ``"module."`` prefix written by ``nn.DataParallel``, the prefix is stripped.
    Anything else is returned unchanged (``load_state_dict`` will then report it).
    """
    if isinstance(checkpoint, torch.nn.Module):
        return checkpoint.state_dict()
    if isinstance(checkpoint, dict):
        for wrapper_key in ("state_dict", "model_state_dict"):
            if isinstance(checkpoint.get(wrapper_key), dict):
                checkpoint = checkpoint[wrapper_key]
                break
        prefix = "module."
        if checkpoint and all(isinstance(k, str) and k.startswith(prefix) for k in checkpoint):
            checkpoint = {k[len(prefix):]: v for k, v in checkpoint.items()}
    return checkpoint


def load_model(model_name, model_path):
    """Build an architecture, load its checkpoint if present, and return it in eval mode.

    Args:
        model_name: One of "IlluminantCNN", "IllumiCam3", "ConfidenceWeightedCNN",
            "ColorConstancyCNN".
        model_path: Path to a ``torch.save`` checkpoint (bare or wrapped state_dict,
            or a pickled module).

    Returns:
        The model on ``DEVICE`` in eval mode. If the file is missing or cannot be
        loaded, the model keeps freshly initialised weights and a message is printed.

    Raises:
        ValueError: If ``model_name`` is not a known architecture.
    """
    print(f"Loading model architecture: {model_name}")

    # Constructors per architecture; all are built with 5 output classes.
    builders = {
        "IlluminantCNN": lambda: IlluminantCNN(num_classes=5),
        "IllumiCam3": lambda: IllumiCam3(num_classes=5),
        "ConfidenceWeightedCNN": lambda: ConfidenceWeightedCNN(num_classes=5),
        # pretrained=False: don't fetch ImageNet weights that the checkpoint replaces anyway.
        "ColorConstancyCNN": lambda: ColorConstancyCNN(K=5, pretrained=False),
    }
    if model_name not in builders:
        raise ValueError(f"Unknown model name: {model_name}")
    build = builders[model_name]
    model = build().to(DEVICE)

    if os.path.exists(model_path):
        try:
            checkpoint = torch.load(model_path, map_location=DEVICE)
            state_dict = _unwrap_state_dict(checkpoint)
            # strict=False tolerates minor mismatches, but it also "succeeds" when no key
            # matches at all (leaving random weights). Report mismatches so that is visible.
            incompatible = model.load_state_dict(state_dict, strict=False)
            if incompatible.missing_keys:
                print(f"Warning: {len(incompatible.missing_keys)} model key(s) missing from checkpoint, "
                      f"e.g. {incompatible.missing_keys[:3]}")
            if incompatible.unexpected_keys:
                print(f"Warning: {len(incompatible.unexpected_keys)} checkpoint key(s) not used by the model, "
                      f"e.g. {incompatible.unexpected_keys[:3]}")
            if len(incompatible.unexpected_keys) == len(state_dict):
                print("Warning: no checkpoint key matched the model; weights are effectively untrained.")
            print(f"Loaded weights from {model_path}")
        except Exception as e:
            print(f"Error loading weights: {e}")
            print("Using random/initialized weights.")
            # load_state_dict may have copied some tensors before raising (e.g. on a size
            # mismatch); rebuild so the model really has initialised weights.
            model = build().to(DEVICE)
    else:
        print(f"Model file not found at {model_path}. Using random/initialized weights.")

    model.eval()
    return model

# --- Single-run configuration ----------------------------------------------
# Architecture to evaluate, one of:
#   IlluminantCNN, IllumiCam3, ConfidenceWeightedCNN, ColorConstancyCNN
MODEL_NAME = "IlluminantCNN"
# Saliency method, one of: GradCAM, GradCAMPlusPlus, ScoreCAM
CAM_METHOD_NAME = "GradCAMPlusPlus"

# Build the chosen architecture and load MODEL_PATH's weights into it.
model = load_model(model_name=MODEL_NAME, model_path=MODEL_PATH)


## 2. Define Metrics
We define the metrics for evaluating the predicted masks (Grad-CAMs) against the Ground Truth masks.

- **IOU (Intersection over Union)**: `TP / (TP + FP + FN)`
- **DICE Score**: `2 * TP / (2 * TP + FP + FN)`
- **P-MAE (Pixel-level Mean Absolute Error)**: `mean(|GT - Pred|)`

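Note: GT masks hold continuous per-pixel values in [0, 1], and the CAM heatmaps are also normalized to [0, 1]. For IOU and DICE, both are binarized with strict `>` thresholds (`GT_THRESHOLD` / `PRED_THRESHOLD`, 0.5 by default; the evaluation matrix at the end uses GT > 0.0 and Pred > 0.1). P-MAE is computed on the continuous values.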

In [None]:
def calculate_metrics(gt_mask, pred_mask, gt_threshold=0.5, pred_threshold=0.5):
    """
    Compare a predicted saliency mask with a ground-truth mask.

    IOU and DICE are computed on binary masks obtained with a strict ``>``
    comparison against each threshold; MAE is computed on the continuous values.

    Args:
        gt_mask: Ground-truth mask (array-like), nominally in [0, 1].
        pred_mask: Predicted mask; must have the same shape as ``gt_mask``.
        gt_threshold: GT pixels with value ``> gt_threshold`` count as positive.
        pred_threshold: Predicted pixels with value ``> pred_threshold`` count as positive.

    Returns:
        Tuple ``(iou, dice, mae)`` of Python floats. When both binary masks are
        empty, IOU and DICE are defined as 1.0 (perfect agreement on "nothing").

    Raises:
        ValueError: If the two masks do not have the same shape.
    """
    gt = np.asarray(gt_mask)
    pred = np.asarray(pred_mask)
    if gt.shape != pred.shape:
        # Without this check NumPy would broadcast and silently produce meaningless scores.
        raise ValueError(f"Mask shape mismatch: gt {gt.shape} vs pred {pred.shape}")

    # MAE on continuous values. Promote to a float dtype first: subtracting unsigned
    # integer masks (e.g. uint8) would wrap around instead of going negative.
    float_dtype = np.result_type(gt.dtype, pred.dtype, np.float32)
    mae = float(np.mean(np.abs(gt.astype(float_dtype, copy=False) - pred.astype(float_dtype, copy=False))))

    # Binarize for IOU/DICE.
    gt_bin = gt > gt_threshold
    pred_bin = pred > pred_threshold

    intersection = int(np.count_nonzero(gt_bin & pred_bin))
    union = int(np.count_nonzero(gt_bin | pred_bin))
    total = int(np.count_nonzero(gt_bin)) + int(np.count_nonzero(pred_bin))

    # union == 0 (equivalently total == 0) only when both masks are empty: a perfect match.
    iou = intersection / union if union else 1.0
    dice = 2 * intersection / total if total else 1.0

    return float(iou), float(dice), mae


## 3. Evaluation Loop
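Run inference on all test images and calculate metrics. For each scene, a CAM heatmap is generated for every illuminant cluster present in the GT mask; clusters whose GT channel is empty are skipped.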

In [None]:
# Setup CAM
def get_cam(model, model_name, method_name):
    """Build a pytorch-grad-cam explainer for ``model``.

    Args:
        model: The network to explain.
        model_name: Architecture name; selects the layer the CAM hooks into.
            Unrecognised names fall back to the model's last ``nn.Conv2d``.
        method_name: "GradCAM", "GradCAMPlusPlus" or "ScoreCAM".

    Returns:
        A CAM object, called as ``cam(input_tensor=..., targets=...)``.

    Raises:
        ValueError: If ``method_name`` is unknown, or the fallback finds no Conv2d layer.
    """
    # Determine target layers based on model architecture.
    if model_name in ("IlluminantCNN", "IllumiCam3", "ConfidenceWeightedCNN"):
        # conv5: for ConfidenceWeightedCNN this is the last shared conv layer;
        # presumably the final conv layer for the other two as well.
        target_layers = [model.conv5]
    elif model_name == "ColorConstancyCNN":
        # model.features is AlexNet's Sequential; index 10 is its last
        # Conv2d(256, 256, kernel_size=3, stride=1, padding=1).
        target_layers = [model.features[10]]
    else:
        # CAMs need spatial activations, so use the last Conv2d rather than the last
        # module overall (typically a Linear or activation layer, which CAM cannot use).
        conv_layers = [m for m in model.modules() if isinstance(m, torch.nn.Conv2d)]
        if not conv_layers:
            raise ValueError(f"No Conv2d layer found to use as CAM target for {model_name}")
        target_layers = [conv_layers[-1]]

    cam_classes = {
        "GradCAM": GradCAM,
        "GradCAMPlusPlus": GradCAMPlusPlus,
        "ScoreCAM": ScoreCAM,
    }
    if method_name not in cam_classes:
        raise ValueError(f"Unknown CAM method: {method_name}")
    return cam_classes[method_name](model=model, target_layers=target_layers)

cam = get_cam(model, MODEL_NAME, CAM_METHOD_NAME)
print(f"Using CAM Method: {CAM_METHOD_NAME} on {MODEL_NAME}")

# Transforms for model input: resize to 224x224, convert to a float tensor and
# normalise with the standard ImageNet per-channel mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Configurable Thresholds: pixels strictly above these count as positive for IOU/DICE.
GT_THRESHOLD = 0.5
PRED_THRESHOLD = 0.5

print(f"Using GT Threshold: {GT_THRESHOLD}")
print(f"Using Pred Threshold: {PRED_THRESHOLD}")

results = []

# Test scenes are the ".nef" raw files in IMAGES_DIR, identified by file name without
# the extension (only the trailing suffix is removed). Sorted because os.listdir order
# is arbitrary, which would make per-scene rows and error output non-reproducible.
test_scenes = sorted(f[:-len(".nef")] for f in os.listdir(IMAGES_DIR) if f.endswith(".nef"))
print(f"Evaluating on {len(test_scenes)} scenes...")

failed_scenes = []  # (scene_id, "ErrorType: message") for every scene that raised

for scene_id in tqdm(test_scenes):
    try:
        # Load Image
        img_path = os.path.join(IMAGES_DIR, f"{scene_id}.nef")

        # Load Raw for Model Input (Linear RGB)
        img_raw = process_raw_image(img_path, srgb=False)

        # Load sRGB; below only its shape is used (GT-mask and CAM resize target)
        img_rgb = process_raw_image(img_path, srgb=True)

        # Prepare for Model (Use Raw)
        img_pil = Image.fromarray(img_raw)
        input_tensor = transform(img_pil).unsqueeze(0).to(DEVICE)

        # Load GT Mask at the image's shape
        mask_path = os.path.join(MASKS_DIR, f"{scene_id}_mask.npy")
        gt_mask = load_mask(mask_path, target_shape=img_rgb.shape)

        # One row per scene; clusters absent from the GT get no entries here and
        # therefore show up as NaN in the DataFrame.
        scene_metrics = {'scene': scene_id}

        for i, cluster_name in enumerate(CLUSTER_NAMES):
            # Assumes GT channel i and model output class i refer to the same cluster.
            gt_channel = gt_mask[:, :, i]

            # Skip if GT is empty
            if gt_channel.max() == 0:
                continue

            targets = [ClassifierOutputTarget(i)]
            grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]

            # Resize CAM to original image size (cv2 takes (width, height))
            pred_mask = cv2.resize(grayscale_cam, (img_rgb.shape[1], img_rgb.shape[0]))

            # Calculate Metrics
            iou, dice, mae = calculate_metrics(gt_channel, pred_mask, gt_threshold=GT_THRESHOLD, pred_threshold=PRED_THRESHOLD)

            scene_metrics[f'{cluster_name}_IOU'] = iou
            scene_metrics[f'{cluster_name}_DICE'] = dice
            scene_metrics[f'{cluster_name}_MAE'] = mae

        results.append(scene_metrics)

    except Exception as e:
        # Best-effort over the test set: keep going, but record the failure so the
        # number of excluded scenes is reported once the loop finishes.
        failed_scenes.append((scene_id, f"{type(e).__name__}: {e}"))
        print(f"Error processing {scene_id}: {e}")

df_results = pd.DataFrame(results)
print("Evaluation Complete.")
if failed_scenes:
    print(f"{len(failed_scenes)} of {len(test_scenes)} scenes failed and are excluded from df_results.")


## 4. Results Summary
Average each metric per cluster over the scenes in which that cluster appears, then average the per-cluster results across clusters.

In [None]:
# Calculate averages.
# Per cluster: mean over the scenes in which that cluster appears. Scenes without it
# have NaN in its columns, which DataFrame.mean skips. A cluster that never appears
# has no column at all.
summary = {}
for metric in ['IOU', 'DICE', 'MAE']:
    for cluster in CLUSTER_NAMES:
        col = f'{cluster}_{metric}'
        if col in df_results.columns:
            summary[col] = df_results[col].mean()


def _format_metric(value):
    """Format a metric for display; missing values print as 'n/a' instead of a misleading 0."""
    return "n/a" if value is None else f"{value:.4f}"


print("Average Metrics:")
for cluster in CLUSTER_NAMES:
    print(f"\nCluster: {cluster}")
    print(f"  IOU:  {_format_metric(summary.get(f'{cluster}_IOU'))}")
    print(f"  DICE: {_format_metric(summary.get(f'{cluster}_DICE'))}")
    print(f"  MAE:  {_format_metric(summary.get(f'{cluster}_MAE'))}")

# Overall Average: unweighted mean of the per-cluster averages. Clusters with no data
# are skipped rather than raising KeyError.
print("\nOverall Averages:")
for label, metric in (("mIOU", "IOU"), ("mDICE", "DICE"), ("mMAE", "MAE")):
    cluster_means = [summary[f'{c}_{metric}'] for c in CLUSTER_NAMES if f'{c}_{metric}' in summary]
    overall = float(np.mean(cluster_means)) if cluster_means else None
    print(f"  {label}: {_format_metric(overall)}")

missing_clusters = [c for c in CLUSTER_NAMES if f'{c}_IOU' not in summary]
if missing_clusters:
    print(f"\nNote: no evaluated scene contains cluster(s) {missing_clusters}; "
          f"they are excluded from the overall averages.")

In [None]:
# Comprehensive Evaluation Matrix
# Runs every (model, CAM method) combination: 4 models x 3 CAM methods.
# Each model is loaded from its own checkpoint listed in MODELS.
# Thresholds: GT > 0.0 (any non-zero GT value is positive), Pred > 0.1.
# Aggregation: per scene, metrics are averaged over the clusters present in the GT
# mask; the reported numbers are the mean of those per-scene values. A combination
# with no successfully evaluated scene reports NaN (not 0, which would read as a
# perfect MAE).
# Note: load_model, get_cam, calculate_metrics, transform and test_scenes come from
# previous cells.

# Configuration
MODELS = {
    "IlluminantCNN": "best_illuminant_cnn_val_8084.pth",
    "IllumiCam3": "illumicam3.pth",
    "ColorConstancyCNN": "best_paper_model.pth",
    "ConfidenceWeightedCNN": "best_illuminant_cnn_confidence.pth"
}

CAM_METHODS = ["GradCAM", "GradCAMPlusPlus", "ScoreCAM"]

# Thresholds
MATRIX_GT_THRESHOLD = 0.0
MATRIX_PRED_THRESHOLD = 0.1

matrix_results = []


def _mean_or_nan(values):
    """Return the mean of ``values`` as a float, or NaN when there are none."""
    return float(np.mean(values)) if values else float("nan")


print("Starting Comprehensive Evaluation Matrix...")
print(f"GT Threshold: {MATRIX_GT_THRESHOLD}, Pred Threshold: {MATRIX_PRED_THRESHOLD}")

for model_name, model_path in MODELS.items():
    print(f"\nEvaluating Model: {model_name}")

    # Load Model
    try:
        model = load_model(model_name, model_path)
    except Exception as e:
        print(f"Failed to load {model_name}: {e}")
        continue

    for cam_method in CAM_METHODS:
        print(f"  Using CAM: {cam_method}")

        try:
            cam = get_cam(model, model_name, cam_method)
        except Exception as e:
            print(f"  Failed to init {cam_method}: {e}")
            continue

        # Per-scene averages accumulated for this combination
        combo_metrics = {
            "Model": model_name,
            "CAM": cam_method,
            "mIOU": [],
            "mDICE": [],
            "mMAE": []
        }
        combo_failures = []  # (scene_id, "ErrorType: message") for scenes that raised

        # Run Inference on all scenes
        for scene_id in tqdm(test_scenes, desc=f"{model_name}-{cam_method}", leave=False):
            try:
                # Load Data (Raw for Input; sRGB only for its shape, the mask/CAM target size)
                img_path = os.path.join(IMAGES_DIR, f"{scene_id}.nef")
                img_raw = process_raw_image(img_path, srgb=False)
                img_rgb = process_raw_image(img_path, srgb=True)

                img_pil = Image.fromarray(img_raw)
                input_tensor = transform(img_pil).unsqueeze(0).to(DEVICE)

                mask_path = os.path.join(MASKS_DIR, f"{scene_id}_mask.npy")
                gt_mask = load_mask(mask_path, target_shape=img_rgb.shape)

                scene_ious = []
                scene_dices = []
                scene_maes = []

                for i, cluster_name in enumerate(CLUSTER_NAMES):
                    gt_channel = gt_mask[:, :, i]
                    if gt_channel.max() == 0:
                        continue  # cluster absent from this scene

                    targets = [ClassifierOutputTarget(i)]
                    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
                    pred_mask = cv2.resize(grayscale_cam, (img_rgb.shape[1], img_rgb.shape[0]))

                    # Calculate Metrics with Matrix Thresholds
                    iou, dice, mae = calculate_metrics(gt_channel, pred_mask, gt_threshold=MATRIX_GT_THRESHOLD, pred_threshold=MATRIX_PRED_THRESHOLD)

                    scene_ious.append(iou)
                    scene_dices.append(dice)
                    scene_maes.append(mae)

                if scene_ious:
                    combo_metrics["mIOU"].append(np.mean(scene_ious))
                    combo_metrics["mDICE"].append(np.mean(scene_dices))
                    combo_metrics["mMAE"].append(np.mean(scene_maes))

            except Exception as e:
                # Best-effort over the test set, but never silently: record the failure.
                combo_failures.append((scene_id, f"{type(e).__name__}: {e}"))

        # Average over all scenes
        res = {
            "Model": model_name,
            "CAM": cam_method,
            "mIOU": _mean_or_nan(combo_metrics["mIOU"]),
            "mDICE": _mean_or_nan(combo_metrics["mDICE"]),
            "mMAE": _mean_or_nan(combo_metrics["mMAE"]),
            "Scenes": len(combo_metrics["mIOU"]),  # scenes contributing at least one cluster
            "Failed": len(combo_failures),
        }
        matrix_results.append(res)
        print(f"    -> mIOU: {res['mIOU']:.4f}, mDICE: {res['mDICE']:.4f}, mMAE: {res['mMAE']:.4f}")
        if combo_failures:
            first_scene, first_error = combo_failures[0]
            print(f"    !! {len(combo_failures)} scene(s) failed (first: {first_scene}: {first_error})")

# Display Results
df_matrix = pd.DataFrame(matrix_results)
print("\nEvaluation Matrix Results:")
try:
    display(df_matrix)
except NameError:
    print(df_matrix)
