# Recompute validation metrics from a checkpoint

This notebook loads a trained world-model checkpoint and evaluates it on the
validation split: it reports open-loop rollout errors and recomputes the mean
validation loss metrics. Update the paths in the first code cell before running.


In [None]:
from pathlib import Path
from collections import defaultdict

import hydra
import torch
from omegaconf import OmegaConf

import custom_resolvers  # registers OmegaConf resolvers
from plan import load_model

# ---- user settings: edit these before running the notebook ----
# Training run directory; the cells below read hydra.yaml and
# checkpoints/<CHECKPOINT_NAME> from it.
MODEL_DIR = Path("/path/to/your/output/run").expanduser()
# File name inside MODEL_DIR / "checkpoints" (e.g. "model_10.pth").
CHECKPOINT_NAME = "model_latest.pth"
# Validation DataLoader settings; lower BATCH_SIZE when running on CPU.
BATCH_SIZE = 64
NUM_WORKERS = 0

# Prefer the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


In [None]:
# Re-read the configuration the run was trained with; it is stored as
# hydra.yaml in the run directory.
model_cfg = OmegaConf.load(MODEL_DIR / "hydra.yaml")

# Invoke the dataset factory referenced by the config with the same history /
# prediction / frameskip settings used for training. It returns a pair: split
# datasets (indexed by "valid" below) and a second collection indexed the same
# way -- presumably full trajectories, since the rollout cell consumes
# traj_dsets["valid"]; confirm against the factory.
datasets, traj_dsets = hydra.utils.call(
    model_cfg.env.dataset,
    frameskip=model_cfg.dataset.frameskip,
    num_pred=model_cfg.model.num_pred,
    num_hist=model_cfg.model.num_hist,
)

# Deterministic (unshuffled) iteration over the validation split.
val_dataset = datasets["valid"]
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    num_workers=NUM_WORKERS,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

print(f"Validation samples: {len(val_dataset)}")


In [None]:
# Restore the trained model from MODEL_DIR / "checkpoints" / CHECKPOINT_NAME.
model_ckpt = MODEL_DIR / "checkpoints" / CHECKPOINT_NAME

# Overrides handed to load_model alongside the saved training config.
cfg_dict = dict(
    plan_action_type=str(model_cfg.model.plan_action_type),
    action_dim=int(val_dataset.action_dim),
    is_training=False,
)
num_action_repeat = int(model_cfg.model.num_action_repeat)
# Checkpoint entries load_model is asked to require (presumably it errors if
# they are missing -- confirm in plan.load_model).
required_keys = {"model"}

model, _ = load_model(
    model_ckpt=model_ckpt,
    model_cfg=model_cfg,
    cfg_dict=cfg_dict,
    num_action_repeat=num_action_repeat,
    required_keys=required_keys,
    device=device,
)

# Put the weights on the evaluation device and switch the module to
# inference mode before computing any metrics.
model = model.to(device)
model.eval()
print("Model loaded and set to eval mode.")


In [None]:
from einops import rearrange
from utils import slice_trajdict_with_t

def err_eval_single(model, z_pred, z_tgt):
    """Score each predicted entry against its target with ``model.emb_criterion``.

    Args:
        model: Object exposing an ``emb_criterion(pred, target)`` callable.
        z_pred: Mapping of key -> prediction; its keys drive the evaluation.
        z_tgt: Mapping holding a target for every key of ``z_pred``.

    Returns:
        dict mapping each key of ``z_pred`` to ``emb_criterion(z_pred[k], z_tgt[k])``.
    """
    return {key: model.emb_criterion(pred, z_tgt[key]) for key, pred in z_pred.items()}

@torch.no_grad()
def openloop_rollout(model, dset, cfg, num_rollout=10, rand_start_end=True, min_horizon=5,
                     max_sample_attempts=1000):
    """Measure the open-loop rollout error of ``model`` on sampled trajectories.

    For each of ``num_rollout`` trajectories drawn from ``dset``, a window is cut
    (random start and horizon when ``rand_start_end`` is True, otherwise the
    whole trajectory) and sub-sampled every ``cfg.dataset.frameskip`` steps. The
    window is rolled out with ``model.rollout`` using ``cfg.model.num_hist``
    initial observations and, separately, a single initial observation. The last
    rolled-out embedding is compared with ``model.encode_obs`` of the last
    ground-truth frame via ``model.emb_criterion``.

    Runs under ``torch.no_grad()``: this is evaluation only, so no autograd
    graph is built.

    Args:
        model: World model exposing ``encode_obs``, ``rollout``,
            ``emb_criterion`` and ``use_action_encoder``.
        dset: Indexable dataset whose items unpack as ``(obs, act, _, _)``,
            with ``obs`` a dict of tensors containing ``"visual"``.
        cfg: Training config; reads ``training.seed``, ``model.num_hist`` and
            ``dataset.frameskip``.
        num_rollout: Number of trajectories to evaluate.
        rand_start_end: Sample a random start offset and horizon per trajectory.
        min_horizon: Minimum horizon (in frameskip steps), added to
            ``cfg.model.num_hist``.
        max_sample_attempts: Maximum draws per rollout when searching for a
            trajectory long enough for ``rand_start_end``.

    Returns:
        dict ``"z_{key}_err_rollout{postfix}"`` -> mean criterion value over
        the rollouts.

    Raises:
        ValueError: If ``max_sample_attempts`` is smaller than 1.
        RuntimeError: If no long-enough trajectory is drawn within
            ``max_sample_attempts`` attempts.
    """
    if max_sample_attempts < 1:
        raise ValueError("max_sample_attempts must be at least 1")

    # Reseed so the sampled trajectories, offsets and horizons are reproducible.
    torch.manual_seed(cfg.training.seed)
    frameskip = cfg.dataset.frameskip
    min_horizon = min_horizon + cfg.model.num_hist
    logs = {}

    # Roll out once from the full history window and once from a single frame.
    num_past = [(cfg.model.num_hist, ""), (1, "_1framestart")]

    for _rollout in range(num_rollout):
        # Draw trajectories until one is long enough (only constrained when
        # rand_start_end is set); the attempt cap replaces an unbounded loop.
        for _attempt in range(max_sample_attempts):
            traj_idx = torch.randint(0, len(dset), ()).item()
            obs, act, _, _ = dset[traj_idx]
            traj_len = obs["visual"].shape[0]
            if not rand_start_end:
                start = 0
                horizon = (traj_len - 1) // frameskip
                break
            if traj_len > min_horizon * frameskip + 1:
                start = torch.randint(0, traj_len - min_horizon * frameskip - 1, ()).item()
            else:
                start = 0
            max_horizon = (traj_len - start - 1) // frameskip
            if max_horizon > min_horizon:
                horizon = torch.randint(min_horizon, max_horizon + 1, ()).item()
                break
        else:
            raise RuntimeError(
                f"No trajectory with a horizon longer than {min_horizon} steps "
                f"(frameskip={frameskip}) found in {max_sample_attempts} attempts"
            )

        # Keep every frameskip-th frame, horizon + 1 frames in total. Build a new
        # dict rather than writing into the object returned by the dataset.
        window = slice(start, start + horizon * frameskip + 1, frameskip)
        obs = {k: v[window] for k, v in obs.items()}
        # Merge the frameskip raw actions between kept frames into one vector per step.
        act = act[start : start + horizon * frameskip].to(device)
        act = rearrange(act, "(h f) d -> h (f d)", f=frameskip)

        # Target: the last ground-truth frame with two leading singleton dims.
        obs_g = {k: v[-1].unsqueeze(0).unsqueeze(0).to(device) for k, v in obs.items()}
        z_g = model.encode_obs(obs_g)
        actions = act.unsqueeze(0) if model.use_action_encoder else None

        for n_past, postfix in num_past:
            obs_full = {k: v.unsqueeze(0).to(device) for k, v in obs.items()}
            z_obses, _ = model.rollout(obs_full, actions, num_obs_init=n_past)
            z_obs_last = slice_trajdict_with_t(z_obses, start_idx=-1, end_idx=None)
            div_loss = err_eval_single(model, z_obs_last, z_g)

            for k, err in div_loss.items():
                logs.setdefault(f"z_{k}_err_rollout{postfix}", []).append(err)

    return {key: sum(values) / len(values) for key, values in logs.items() if values}

# Open-loop rollout error over 10 randomly windowed validation trajectories.
rollout_logs = openloop_rollout(
    model,
    traj_dsets["valid"],
    model_cfg,
    num_rollout=10,
    rand_start_end=True,
)
# Reduce each metric to a plain Python float and prefix its name with "val_".
rollout_logs = {
    "val_" + name: metric.detach().mean().cpu().item() if torch.is_tensor(metric) else float(metric)
    for name, metric in rollout_logs.items()
}

print("Open-loop rollout metrics (mean over rollouts):")
for key in sorted(rollout_logs):
    value = rollout_logs[key]
    print(f"  {key}: {value:.6f}")


In [None]:
def move_to_device(value, device):
    """Recursively move every tensor inside ``value`` to ``device``.

    Tensors are moved with ``.to(device)``. Dicts are rebuilt as plain dicts,
    and lists, tuples and namedtuples are rebuilt with their own type, each with
    moved contents. Any other object is returned unchanged.

    Args:
        value: A tensor, or an arbitrarily nested dict/list/tuple/namedtuple
            containing tensors and other objects.
        device: Target device (anything accepted by ``Tensor.to``).

    Returns:
        A structure of the same shape with all tensors on ``device``.
    """
    if torch.is_tensor(value):
        return value.to(device)
    if isinstance(value, dict):
        return {k: move_to_device(v, device) for k, v in value.items()}
    if isinstance(value, tuple) and hasattr(value, "_fields"):
        # namedtuple constructors take the fields positionally, not one iterable.
        return type(value)(*(move_to_device(v, device) for v in value))
    if isinstance(value, (list, tuple)):
        return type(value)(move_to_device(v, device) for v in value)
    return value

# Accumulate batch-size-weighted sums so the result is a per-sample mean even
# when the last batch is smaller (assumes `loss` and each component are
# averages over the batch -- confirm in the model's forward).
totals = defaultdict(float)
total_samples = 0

with torch.no_grad():
    for obs, act, _ in val_loader:
        obs = move_to_device(obs, device)
        act = move_to_device(act, device)

        _, _, _, loss, loss_components, _ = model(obs, act)

        batch_size = act.shape[0]
        total_samples += batch_size
        totals["loss"] += float(loss) * batch_size

        # float() accepts both single-element tensors and plain Python numbers.
        for key, value in loss_components.items():
            totals[key] += float(value) * batch_size

if total_samples == 0:
    # Avoid a bare ZeroDivisionError when the validation split is empty.
    raise RuntimeError("Validation loader yielded no samples; cannot compute metrics.")

avg_metrics = {k: v / total_samples for k, v in totals.items()}

print("Validation metrics (mean over validation set):")
for key, value in sorted(avg_metrics.items()):
    print(f"  {key}: {value:.6f}")
