# Simple Model Evaluation

Load a single model checkpoint and compute per-class IoU on the validation set. The seagrass IoU should match the W&B metric `val/iou_epoch/seagrass`.

In [None]:
import sys
from pathlib import Path
sys.path.insert(0, str(Path.cwd().parent))

import torch
import yaml
import albumentations as A
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import torchmetrics

from src.data import NpzSegmentationDataset
from src.models.smp import SMPMulticlassSegmentationModel

In [None]:
# CONFIG - edit these paths
# Experiment YAML; its "model" and "data" sections are read further down.
CONFIG_PATH = "../configs/seagrass-rgb/architecture-experiment/segformer_mitb2_1024.yaml"
# Checkpoint file whose "state_dict" entry is loaded into the model.
CKPT_PATH = "/mnt/class_data/sdalgarno/checkpoints/architecture-experiment/segformer-1024/last.ckpt"

# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Device: {DEVICE}")

In [None]:
# Load config
# Parse the experiment YAML and keep the two "init_args" sections that the
# rest of the notebook reads. The encoding is given explicitly so that parsing
# does not depend on the platform's locale default.
with open(CONFIG_PATH, encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

# Constructor arguments for the model and for the data module, respectively.
model_args = config["model"]["init_args"]
data_args = config["data"]["init_args"]

print(f"Model: {model_args['architecture']} / {model_args['backbone']}")
print(f"Val dir: {data_args['val_chip_dir']}")

In [None]:
# Load model
# Read the checkpoint onto the CPU first; the model is moved to DEVICE below.
# NOTE(review): weights_only=False permits full unpickling of the file, so
# only point CKPT_PATH at checkpoints from a trusted source.
checkpoint = torch.load(CKPT_PATH, map_location="cpu", weights_only=False)

# Build the network from the config's init args, then restore its weights
# from the checkpoint's "state_dict" entry.
model = SMPMulticlassSegmentationModel(**model_args)
model.load_state_dict(checkpoint["state_dict"])

# Place the model on the evaluation device and switch it to eval mode.
model.to(DEVICE)
model.eval()

print("Model loaded")

In [None]:
# Load validation dataset
# Rebuild the albumentations pipeline from the serialized "test_transforms"
# entry of the data config.
test_transforms = A.from_dict(data_args["test_transforms"])

# Dataset over the validation chip directory named in the config.
val_dataset = NpzSegmentationDataset(
    data_args["val_chip_dir"],
    transforms=test_transforms,
)

# Fixed-order loader: batches of 8, four worker processes, no shuffling.
val_loader = DataLoader(
    val_dataset,
    shuffle=False,
    batch_size=8,
    num_workers=4,
)

print(f"Validation tiles: {len(val_dataset)}")

In [None]:
# Create metric - same as training
# Class count comes straight from the model config; ignore_index falls back
# to -100 when the config does not define it.
num_classes = model_args["num_classes"]
ignore_index = model_args["ignore_index"] if "ignore_index" in model_args else -100

# average="none" keeps one IoU value per class instead of a single mean.
iou_metric = torchmetrics.JaccardIndex(
    task="multiclass",
    average="none",
    num_classes=num_classes,
    ignore_index=ignore_index,
)
# Keep the metric on the same device as the predictions it will receive.
iou_metric = iou_metric.to(DEVICE)

print(f"Num classes: {num_classes}, Ignore index: {ignore_index}")

In [None]:
# Run evaluation
# Start from empty metric state so that re-running this cell (for example
# after loading a different checkpoint) never mixes in statistics that were
# accumulated by an earlier pass.
iou_metric.reset()

model.eval()
with torch.no_grad():
    for images, labels in tqdm(val_loader, desc="Evaluating"):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        # argmax over dim 1 (presumably the class dimension - confirm against
        # the model's output layout) turns logits into predicted labels.
        logits = model(images)
        preds = logits.argmax(dim=1)

        # Accumulate this batch into the running metric state.
        iou_metric.update(preds, labels)

# Compute final IoU
iou_per_class = iou_metric.compute()

# Index 1 is reported as seagrass and index 0 as background, following the
# labels printed here - confirm against the dataset's class mapping.
print(f"\nIoU per class: {iou_per_class}")
print(f"\n=== IoU (seagrass): {iou_per_class[1].item():.4f} ===")
print(f"=== IoU (background): {iou_per_class[0].item():.4f} ===")