# Recompute validation metrics from a checkpoint

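This notebook loads a trained world-model checkpoint and runs a single pass
over the validation split to recompute its loss metrics. Before running, set
`MODEL_DIR` (and, if needed, `CHECKPOINT_NAME`) in the first code cell.
`MODEL_DIR` must contain the training run's `hydra.yaml` and a `checkpoints/`
directory.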


In [None]:
from pathlib import Path
from collections import defaultdict

import hydra
import torch
from omegaconf import OmegaConf

import custom_resolvers  # registers OmegaConf resolvers
from plan import load_model

# ---- update these paths ----
# Output directory of the training run (later cells read hydra.yaml and
# checkpoints/ from here).
MODEL_DIR = Path("/path/to/your/output/run").expanduser()
# Checkpoint file inside MODEL_DIR / "checkpoints", e.g. "model_10.pth".
CHECKPOINT_NAME = "model_latest.pth"
# Validation batch size; reduce it when running on CPU.
BATCH_SIZE = 64
# DataLoader worker processes (0 loads data in the main process).
NUM_WORKERS = 0

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


In [None]:
# Restore the exact config the run was trained with; it lives as hydra.yaml
# directly inside MODEL_DIR.
model_cfg = OmegaConf.load(MODEL_DIR / "hydra.yaml")

# Rebuild the dataset splits through the same Hydra target used in training,
# using the model's history/prediction window and the dataset frame skip.
dataset_target = model_cfg.env.dataset
window_kwargs = dict(
    num_hist=model_cfg.model.num_hist,
    num_pred=model_cfg.model.num_pred,
    frameskip=model_cfg.dataset.frameskip,
)
datasets, traj_dsets = hydra.utils.call(dataset_target, **window_kwargs)

val_dataset = datasets["valid"]
# Keep a fixed order (no shuffling) so repeated evaluations see identical batches.
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    shuffle=False,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

print(f"Validation samples: {len(val_dataset)}")


In [None]:
# Checkpoint file written by the training run.
model_ckpt = MODEL_DIR / "checkpoints" / CHECKPOINT_NAME

# Settings forwarded to load_model, taken from the training config and the
# validation dataset.
cfg_dict = dict(
    plan_action_type=str(model_cfg.model.plan_action_type),
    action_dim=int(val_dataset.action_dim),
    is_training=False,
)

num_action_repeat = int(model_cfg.model.num_action_repeat)
# Presumably the checkpoint entries load_model must find — verify in plan.py.
required_keys = {"model"}

model, _ = load_model(
    model_ckpt=model_ckpt,
    model_cfg=model_cfg,
    cfg_dict=cfg_dict,
    num_action_repeat=num_action_repeat,
    required_keys=required_keys,
    device=device,
)

# Make sure the weights are on the chosen device and disable training-only
# behaviour before evaluating.
model = model.to(device)
model.eval()
print("Model loaded and set to eval mode.")


In [None]:
def move_to_device(value, device):
    """Recursively move every tensor inside ``value`` to ``device``.

    Tensors are moved with ``Tensor.to``. Dicts, lists, tuples and
    namedtuples are rebuilt with their elements processed recursively. Any
    other object is returned unchanged.

    Args:
        value: A tensor, a (possibly nested) container of tensors, or any
            other object.
        device: Target forwarded to ``Tensor.to`` (e.g. a ``torch.device``
            or a string such as ``"cpu"``).

    Returns:
        An object with the same nesting in which every tensor lives on
        ``device``. Dicts (including dict subclasses) are returned as plain
        ``dict``.
    """
    if torch.is_tensor(value):
        return value.to(device)
    if isinstance(value, dict):
        return {k: move_to_device(v, device) for k, v in value.items()}
    if isinstance(value, tuple) and hasattr(value, "_fields"):
        # namedtuple constructors take their fields positionally, not as a
        # single iterable, so the generic branch below would raise TypeError.
        return type(value)(*(move_to_device(v, device) for v in value))
    if isinstance(value, (list, tuple)):
        return type(value)(move_to_device(v, device) for v in value)
    return value

# Accumulate sample-weighted sums so the final numbers are per-sample means,
# even when the last batch is smaller than BATCH_SIZE.
totals = defaultdict(float)
total_samples = 0

with torch.no_grad():
    for obs, act, _ in val_loader:
        obs = move_to_device(obs, device)
        act = move_to_device(act, device)

        _, _, _, loss, loss_components, _ = model(obs, act)

        batch_size = act.shape[0]
        total_samples += batch_size
        totals["loss"] += loss.item() * batch_size

        for key, value in loss_components.items():
            # The returned total is already recorded under "loss"; if the
            # model also reports a "loss" component, adding it here would
            # double-count it.
            if key == "loss":
                continue
            # Components may be 0-d tensors or plain Python numbers.
            scalar = value.item() if torch.is_tensor(value) else float(value)
            totals[key] += scalar * batch_size

if total_samples == 0:
    # Fail with a clear message instead of a ZeroDivisionError below.
    raise RuntimeError(
        "Validation loader produced no samples; check MODEL_DIR and the dataset config."
    )

avg_metrics = {k: v / total_samples for k, v in totals.items()}

print("Validation metrics (mean over validation set):")
for key, value in sorted(avg_metrics.items()):
    print(f"  {key}: {value:.6f}")
