# Model Evaluation

Evaluate segmentation model performance on validation and test sets.

Metrics are computed with torchmetrics, the same way as during training, so the numbers should match those logged to W&B.

In [None]:
import importlib
import re
from pathlib import Path

import albumentations as A
import numpy as np
import pandas as pd
import torch
import torchmetrics
import yaml
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from src.data import NpzSegmentationDataset

## Configuration

In [None]:
# Paths - UPDATE THESE
CONFIG_PATH = Path("../configs/seagrass-rgb/unetpp_resnet34_minimal_aug.yaml")
# Checkpoint to evaluate (or download one from W&B).
CKPT_PATH = Path("../checkpoints/unetpp_resnet34_minimal_aug/last.ckpt")

# Inference settings: DataLoader batch size / worker count, and the device
# to run on (GPU when available, otherwise CPU).
BATCH_SIZE, NUM_WORKERS = 16, 4
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print(f"Using device: {DEVICE}")

## Load Model

In [None]:
# Load the training config. YAML is normally UTF-8; be explicit so a
# non-UTF-8 platform default encoding cannot mis-decode non-ASCII content.
with open(CONFIG_PATH, encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Extract model and data settings
model_config = config["model"]
data_config = config["data"]["init_args"]

num_classes = model_config["init_args"]["num_classes"]
ignore_index = model_config["init_args"].get("ignore_index", -100)
class_names = model_config["init_args"].get("class_names", ["bg", "target"])

print(f"Classes: {class_names}")
print(f"Ignore index: {ignore_index}")

# The per-class report cells index metric vectors by position in class_names:
# too few names silently skips classes, too many raises IndexError.
if num_classes > 1 and len(class_names) != num_classes:
    print(f"WARNING: {len(class_names)} class names configured for {num_classes} classes")

In [None]:
# Instantiate the model class named by `class_path` in the config, using its init args.
class_path = model_config["class_path"]
init_args = model_config["init_args"]

module_name, class_name = class_path.rsplit(".", 1)
module = importlib.import_module(module_name)
model_class = getattr(module, class_name)

model = model_class(**init_args)

# Load checkpoint.
# SECURITY: weights_only=False lets torch.load unpickle arbitrary objects, which
# can execute code. Only load checkpoints from sources you trust.
checkpoint = torch.load(CKPT_PATH, map_location="cpu", weights_only=False)
# Accept checkpoints that nest the weights under "state_dict" as well as bare
# state dicts saved with torch.save(model.state_dict(), ...).
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
    state_dict = checkpoint["state_dict"]
else:
    state_dict = checkpoint
model.load_state_dict(state_dict)
model.eval()
model.to(DEVICE)

print(f"Model loaded from {CKPT_PATH}")

## Load Datasets

In [None]:
# Build the evaluation-time transform pipeline from the serialized albumentations config.
test_transforms = A.from_dict(data_config["test_transforms"])


class EvalDataset(NpzSegmentationDataset):
    """``NpzSegmentationDataset`` that also returns each chip's filename.

    The filename is used later to group per-tile scores by ortho. Relies on
    the parent class providing ``self.chips`` (one chip path per sample) and
    ``self.transforms`` (a callable taking ``image=``/``mask=`` and returning
    a dict, or ``None``).
    """

    def __getitem__(self, idx):
        """Return ``(image, label, filename)`` for chip ``idx``."""
        # Path() also tolerates chips stored as plain strings.
        chip_path = Path(self.chips[idx])
        # Use the NpzFile as a context manager so its file handle is closed
        # immediately instead of whenever it is garbage-collected; the arrays
        # read out of it stay valid after closing.
        with np.load(chip_path) as data:
            image, label = data["image"], data["label"]

        if self.transforms is not None:
            augmented = self.transforms(image=image, mask=label)
            image, label = augmented["image"], augmented["mask"]

        return image, label, chip_path.name


# Load datasets; shuffle=False keeps a deterministic tile order.
val_dataset = EvalDataset(data_config["val_chip_dir"], transforms=test_transforms)
test_dataset = EvalDataset(data_config["test_chip_dir"], transforms=test_transforms)

val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)

print(f"Validation: {len(val_dataset)} samples")
print(f"Test: {len(test_dataset)} samples")

## Setup Metrics

Using torchmetrics to match training metrics exactly.

In [None]:
def create_metrics(num_classes, ignore_index, device):
    """Build the segmentation metrics used for evaluation, mirroring training.

    Args:
        num_classes: Number of model output classes. ``1`` selects the
            torchmetrics ``"binary"`` task; any other value selects ``"multiclass"``.
        ignore_index: Target value excluded from every metric.
        device: Device the metric states are moved to.

    Returns:
        Dict mapping ``"iou"``, ``"precision"``, ``"recall"`` and ``"f1"`` to
        torchmetrics objects built with ``average="none"`` (per-class scores
        for the multiclass task).
    """
    shared_kwargs = dict(
        task="binary" if num_classes == 1 else "multiclass",
        num_classes=num_classes,
        ignore_index=ignore_index,
        average="none",
    )
    metric_types = (
        ("iou", torchmetrics.JaccardIndex),
        ("precision", torchmetrics.Precision),
        ("recall", torchmetrics.Recall),
        ("f1", torchmetrics.F1Score),
    )
    return {key: metric_cls(**shared_kwargs).to(device) for key, metric_cls in metric_types}

def reset_metrics(metrics):
    """Clear the accumulated state of every metric in ``metrics`` (a name -> metric dict)."""
    for _key, metric in metrics.items():
        metric.reset()

## Run Inference

In [None]:
_CHIP_NAME_RE = re.compile(r"^(.+?)_x\d+_y\d+\.npz$")


def extract_ortho(filename):
    """Return the ortho name encoded in a chip filename.

    Chip files are named ``<ortho>_x<digits>_y<digits>.npz``, e.g.
    ``'calmus_u0421_x123_y456.npz' -> 'calmus_u0421'``. Names that do not
    follow that pattern fall back to the filename without a trailing
    ``.npz`` extension.
    """
    match = _CHIP_NAME_RE.match(filename)
    if match:
        return match.group(1)
    # Strip only the trailing extension; str.replace would also remove
    # ".npz" occurring elsewhere in the name.
    return filename[: -len(".npz")] if filename.endswith(".npz") else filename


def _per_tile_target_iou(preds, labels, target_class, ignore_index):
    """Return the IoU of ``target_class`` for every tile of a batch (float64, one value per tile).

    Pixels whose label equals ``ignore_index`` are excluded, like torchmetrics'
    ``ignore_index``. A tile where neither the prediction nor the label
    contains the target class has an undefined IoU (0/0) and gets NaN, so
    pandas aggregations skip it instead of counting it as 0.
    """
    # Compare in int64 so a negative ignore_index never meets a uint8 mask.
    labels = labels.long()
    if ignore_index is None:
        valid = torch.ones_like(labels, dtype=torch.bool)
    else:
        valid = labels != ignore_index
    pred_pos = (preds == target_class) & valid
    true_pos = (labels == target_class) & valid
    pixel_dims = tuple(range(1, preds.ndim))  # reduce everything but the batch dim
    intersection = (pred_pos & true_pos).sum(dim=pixel_dims).double()
    union = (pred_pos | true_pos).sum(dim=pixel_dims).double()
    return torch.where(
        union > 0,
        intersection / union.clamp(min=1.0),
        torch.full_like(union, float("nan")),
    )


@torch.no_grad()
def evaluate(model, dataloader, metrics, device, num_classes, ignore_index=None, target_class=1):
    """Run the model over ``dataloader``; return global metrics and per-tile IoU.

    Args:
        model: Segmentation model returning logits with classes on dim 1. A
            single channel is thresholded as a sigmoid logit at 0.5; otherwise
            the argmax over classes is taken.
        dataloader: Yields ``(images, labels, filenames)`` batches.
        metrics: Dict of torchmetrics objects (as built by ``create_metrics``);
            reset at the start and updated with every batch.
        device: Device to run inference on.
        num_classes: Number of classes the model predicts; used to validate
            ``target_class``.
        ignore_index: Label value excluded from the per-tile IoU. When None,
            the ``ignore_index`` stored on ``metrics["iou"]`` is used so the
            per-tile and global numbers ignore the same pixels; if that
            attribute is missing no pixels are excluded (pass it explicitly then).
        target_class: Class index whose per-tile IoU is reported as ``iou_target``.

    Returns:
        ``(final_metrics, df)``. ``final_metrics`` maps each metric name to its
        computed value. ``df`` has one row per tile with columns ``filename``,
        ``ortho`` and ``iou_target``. ``iou_target`` is NaN when the target
        class is absent from both the prediction and the label.

    Raises:
        ValueError: If ``target_class`` is not a valid class index.
    """
    # A single-logit head still yields predictions in {0, 1}.
    if not 0 <= target_class < max(num_classes, 2):
        raise ValueError(f"target_class={target_class} is out of range for num_classes={num_classes}")
    if ignore_index is None:
        ignore_index = getattr(metrics.get("iou"), "ignore_index", None)

    model.eval()
    reset_metrics(metrics)

    results = []
    for images, labels, filenames in tqdm(dataloader, desc="Evaluating"):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass + hard predictions
        logits = model(images)
        if logits.shape[1] == 1:
            preds = (torch.sigmoid(logits).squeeze(1) > 0.5).long()
        else:
            preds = torch.softmax(logits, dim=1).argmax(dim=1)

        # Update global metrics
        for m in metrics.values():
            m.update(preds, labels)

        # Vectorised per-tile IoU for the per-ortho breakdown
        tile_iou = _per_tile_target_iou(preds, labels, target_class, ignore_index).tolist()
        results.extend(
            {"filename": fname, "ortho": extract_ortho(fname), "iou_target": iou}
            for fname, iou in zip(filenames, tile_iou)
        )

    final_metrics = {name: m.compute() for name, m in metrics.items()}
    # Explicit columns keep the frame well-formed even for an empty dataloader.
    return final_metrics, pd.DataFrame(results, columns=["filename", "ortho", "iou_target"])

## Evaluate Validation Set

In [None]:
# Fresh metric objects so validation and test accumulate separately.
val_metrics = create_metrics(num_classes=num_classes, ignore_index=ignore_index, device=DEVICE)
val_results, val_df = evaluate(model, val_loader, val_metrics, DEVICE, num_classes=num_classes)

In [None]:
print("=" * 50)
print("VALIDATION SET METRICS")
print("=" * 50)
print(f"\nPer-class metrics:")
for i, name in enumerate(class_names):
    print(f"\n  {name}:")
    print(f"    IoU:       {val_results['iou'][i]:.4f}")
    print(f"    Precision: {val_results['precision'][i]:.4f}")
    print(f"    Recall:    {val_results['recall'][i]:.4f}")
    print(f"    F1:        {val_results['f1'][i]:.4f}")

# Mean across non-background classes (matches W&B val/iou_epoch)
mean_iou = val_results['iou'][1:].mean().item()
print(f"\n{'='*50}")
print(f"Mean IoU (non-bg): {mean_iou:.4f}  <- Should match W&B val/iou_epoch")
print(f"{'='*50}")

## Evaluate Test Set

In [None]:
# Fresh metric objects so test results do not mix with validation.
test_metrics = create_metrics(num_classes=num_classes, ignore_index=ignore_index, device=DEVICE)
test_results, test_df = evaluate(model, test_loader, test_metrics, DEVICE, num_classes=num_classes)

In [None]:
print("=" * 50)
print("TEST SET METRICS")
print("=" * 50)
print(f"\nPer-class metrics:")
for i, name in enumerate(class_names):
    print(f"\n  {name}:")
    print(f"    IoU:       {test_results['iou'][i]:.4f}")
    print(f"    Precision: {test_results['precision'][i]:.4f}")
    print(f"    Recall:    {test_results['recall'][i]:.4f}")
    print(f"    F1:        {test_results['f1'][i]:.4f}")

mean_iou = test_results['iou'][1:].mean().item()
print(f"\n{'='*50}")
print(f"Mean IoU (non-bg): {mean_iou:.4f}")
print(f"{'='*50}")

## Per-Ortho Breakdown

In [None]:
# Aggregate by ortho
def ortho_summary(df):
    """Summarise per-tile target IoU by ortho, best ortho first.

    Args:
        df: Per-tile frame with ``ortho`` and ``iou_target`` columns (as
            returned by ``evaluate``). ``iou_target`` may contain NaN.

    Returns:
        DataFrame indexed by ortho with the mean/std/min/max of
        ``iou_target``, ``n_tiles`` (every tile of the ortho) and
        ``n_scored`` (tiles with a non-NaN IoU). Rows are sorted by mean IoU,
        descending, and values are rounded to 4 decimals. An empty input
        gives an empty frame with the same columns.
    """
    columns = ["iou_mean", "iou_std", "iou_min", "iou_max", "n_tiles", "n_scored"]
    if df.empty:
        # A column-less frame (no tiles at all) would make groupby raise KeyError.
        return pd.DataFrame(columns=columns).rename_axis("ortho")

    grouped = df.groupby("ortho")
    summary = grouped.agg(
        iou_mean=("iou_target", "mean"),
        iou_std=("iou_target", "std"),
        iou_min=("iou_target", "min"),
        iou_max=("iou_target", "max"),
        n_scored=("iou_target", "count"),
    )
    # size() counts every tile; "count" on iou_target would skip NaN scores.
    summary.insert(4, "n_tiles", grouped.size())
    return summary.sort_values("iou_mean", ascending=False).round(4)

print("VALIDATION - Per-Ortho IoU (target class)")
print("=" * 70)
display(ortho_summary(val_df))

In [None]:
print("TEST - Per-Ortho IoU (target class)")
print("=" * 70)
display(ortho_summary(test_df))

## Summary Comparison

In [None]:
# Side-by-side validation vs. test scores for the target class.
target_idx = 1  # seagrass class

summary = pd.DataFrame(
    {
        "Metric": ["IoU", "Precision", "Recall", "F1"],
        **{
            split_label: [
                split_results[metric_key][target_idx].item()
                for metric_key in ("iou", "precision", "recall", "f1")
            ]
            for split_label, split_results in (("Validation", val_results), ("Test", test_results))
        },
    }
).round(4)

print(f"\nSUMMARY - {class_names[target_idx]} class metrics")
print("=" * 50)
display(summary)

## Visualize Per-Ortho Performance

In [None]:
import matplotlib.pyplot as plt

# One bar chart per split: mean per-tile IoU of each ortho with std error
# bars, plus a dashed line at the mean of the per-ortho means.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
splits = (("Validation", val_df), ("Test", test_df))

for ax, (split_name, split_df) in zip(axes, splits):
    stats = ortho_summary(split_df)
    positions = range(len(stats))
    overall_mean = stats["iou_mean"].mean()

    ax.bar(positions, stats["iou_mean"], yerr=stats["iou_std"], capsize=3, color="steelblue", alpha=0.8)
    ax.axhline(overall_mean, color="red", linestyle="--", label=f"Mean: {overall_mean:.3f}")

    ax.set_xticks(positions)
    ax.set_xticklabels(stats.index, rotation=45, ha="right", fontsize=8)
    ax.set_ylabel("IoU")
    ax.set_ylim(0, 1)
    ax.set_title(f"{split_name} - IoU by Ortho")
    ax.legend()
    ax.grid(axis="y", alpha=0.3)

plt.tight_layout()
plt.show()