## AutoCast Processor Evaluation

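This notebook evaluates a pre-trained processor model on the MiniWell dataset (the run configured below is Rayleigh–Bénard).
It loads the model configuration and weights from a specified run directory, compares teacher-forced and free-running
predictions in latent space, and then decodes them with a pre-trained autoencoder for comparison in physical space.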


In [None]:
import os

import matplotlib.pyplot as plt
import torch
from hydra.utils import instantiate
from IPython.display import HTML
from omegaconf import OmegaConf

from autocast.external.lola.lola_autoencoder import get_autoencoder
from autocast.models.processor import ProcessorModel
from autocast.utils.plots import plot_spatiotemporal_video

# Pick the best available accelerator, falling back to CPU so the notebook also
# runs (and `torch.load(..., map_location=device)` succeeds) on machines without
# Apple-silicon MPS support.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Training-run directory holding the resolved config and the processor checkpoint.
run_path = "../outputs/rayleigh_benard/2026-01-14_diffusion_vit_small"
config_path, ckpt_path = (
    os.path.join(run_path, fname)
    for fname in ("resolved_processor_config.yaml", "processor.ckpt")
)

# Resolved config saved alongside the run; uncomment the print to inspect it.
cfg = OmegaConf.load(config_path)
# print(OmegaConf.to_yaml(cfg))

In [None]:
# Build the data module described by the run config, then set it up with no
# explicit stage (per the original note: fit -> train/val, test -> test).
data_cfg = cfg.data
datamodule = instantiate(data_cfg)
datamodule.setup()

In [None]:
# Build the processor network from the config, then wrap it in ProcessorModel
# using the learning rate recorded in the same config section.
model_cfg = cfg.model
processor = instantiate(model_cfg.processor)
model = ProcessorModel(
    processor=processor,
    learning_rate=model_cfg.learning_rate,
)

In [None]:
# Restore trained weights; the checkpoint keeps them under its "state_dict" key
# (presumably a PyTorch Lightning checkpoint — confirm).
# NOTE(review): `model` itself is never moved to `device`. `load_state_dict`
# copies the loaded tensors into the model's existing parameters, so the model
# presumably stays where it was created (CPU by default); `map_location` only
# controls where the checkpoint tensors are staged while loading.
checkpoint = torch.load(
    ckpt_path,
    map_location=device,
    weights_only=True,
)
weights = checkpoint["state_dict"]
model.load_state_dict(weights)
model.eval()
print("Model loaded successfully")

In [None]:
# Free-running rollout, computed before the teacher-forced predictions.
from autocast.types.batch import EncodedBatch

val_batch = next(iter(datamodule.val_dataloader()))

# Keep only the first sample so the autoregressive rollout stays cheap.
val_global_cond = val_batch.global_cond
batch = EncodedBatch(
    encoded_inputs=val_batch.encoded_inputs[:1],
    encoded_output_fields=val_batch.encoded_output_fields[:1],
    global_cond=None if val_global_cond is None else val_global_cond[:1],
    encoded_info={},
)

# Advance the rollout by one full output window per step.
window_length = batch.encoded_output_fields.shape[1]
with torch.no_grad():
    preds_free_running = model.rollout(
        batch,
        stride=window_length,
        max_rollout_steps=20,
        free_running_only=True,
    )

In [None]:
# One-step (teacher-forced) predictions: every prediction is conditioned on the
# ground-truth inputs of its own validation batch.
n_eval_batches = 4  # limit for demonstration purposes
pred_chunks, true_chunks = [], []
with torch.no_grad():
    # zip() with a finite range stops after n_eval_batches without drawing an
    # extra batch from the loader.
    for _, val_batch in zip(range(n_eval_batches), datamodule.val_dataloader()):
        pred_chunks.append(model(val_batch.encoded_inputs, val_batch.global_cond))
        true_chunks.append(val_batch.encoded_output_fields)
    preds = torch.cat(pred_chunks, dim=0)
    trues = torch.cat(true_chunks, dim=0)

print(f"Predictions shape: {preds.shape}")
print(f"Ground Truth shape: {trues.shape}")

In [None]:
from einops import rearrange

# Stitch selected validation samples into one long latent sequence for plotting.
# NOTE(review): indexing `trues`/`preds` along dim 0 as if it were time assumes the
# validation loader is unshuffled and that consecutive samples are consecutive
# windows of a single trajectory — confirm against the datamodule.
max_rollout_steps = 15  # caps the number of candidate sample indices
dataset_stride: int = datamodule.stride
# Candidate sample indices 0, stride, 2*stride, ..., dropping any that fall past
# the number of collected samples.
indices = torch.arange(0, max_rollout_steps) * dataset_stride
indices = indices[indices < trues.shape[0]]
# Alternative: keep only the first output step (index 0 of dim 1) of each sample.
# trues_rollout = rearrange(trues[indices, 0], "B ... C -> 1 B ... C")
# preds_rollout = rearrange(preds[indices, 0], "B ... C -> 1 B ... C")
# Take every 4th candidate index and flatten (sample, time) into a single time
# axis with a leading batch dimension of 1.
# NOTE(review): the hard-coded 4 presumably matches the number of output steps per
# sample so the stitched windows neither overlap nor leave gaps — TODO confirm, or
# derive it from trues.shape[1] and dataset_stride.
trues_rollout = rearrange(trues[indices[::4]], "B T ... C -> 1 (B T) ... C")
preds_rollout = rearrange(preds[indices[::4]], "B T ... C -> 1 (B T) ... C")
print(f"Constructed Ground Truth shape: {trues_rollout.shape}")

In [None]:
# Teacher-forced latent predictions vs. latent ground truth. `trues_rollout` is
# always a tensor here (built with `rearrange` in the previous cell), so no None
# guard is needed. Only the leading latent channels are plotted to keep the
# figure readable.
n_latent_channels_to_plot = 4
anim = plot_spatiotemporal_video(
    true=trues_rollout[..., :n_latent_channels_to_plot],
    pred=preds_rollout[..., :n_latent_channels_to_plot],
    batch_idx=0,
    save_path="teacher_forcing_prediction.mp4",
    title="Teacher Forcing (latent)",
    colorbar_mode="row",
)
HTML(anim.to_jshtml())

In [None]:
# Locate the pre-trained autoencoder used to decode latent predictions.
ae_path = "../datasets/rayleigh_benard/1e3z5x2c_rayleigh_benard_dcae_f32c64_large"
ae_config_path, ae_ckpt_path = (
    os.path.join(ae_path, fname) for fname in ("config.yaml", "state.pth")
)

print(f"Loading AutoEncoder from: {ae_path}")
ae_cfg = OmegaConf.load(ae_config_path)
ae_section = ae_cfg.ae

# A plain, fully resolved dict avoids OmegaConf/beartype conflicts for most
# constructor arguments (e.g. attention_heads) ...
ae_config_dict = OmegaConf.to_container(ae_section, resolve=True)

# ... except `loss`, which get_autoencoder types as DictConfig, so the original
# config node is put back.
if "loss" in ae_section:
    ae_config_dict["loss"] = ae_section.loss

In [None]:
from autocast.external.lola.wrapped_decoder import WrappedDecoder
from autocast.external.lola.wrapped_encoder import WrappedEncoder
from autocast.models.autoencoder import AE

# Rebuild the autoencoder halves from the saved config.
# NOTE(review): only the decoder receives `runpath` and `device`, so presumably
# only it loads trained weights from the run directory — confirm.
encoder = WrappedEncoder(**ae_config_dict)  # type: ignore  # noqa: PGH003
decoder = WrappedDecoder(  # type: ignore  # noqa: PGH003
    device=device,
    batch_size=4,
    runpath=ae_path,
    **ae_config_dict,
)
ae = AE(encoder=encoder, decoder=decoder)
# Assign to `_` so the notebook does not echo the module repr.
_ = ae.eval()

In [None]:
# Decode latent trajectories back to physical space with the autoencoder.


def _to_plot_orientation(x):
    """Reorient a decoded (1, T, H, W, C) tensor for plotting.

    Swaps the two spatial axes, giving (1, T, W, H, C), then flips the first
    spatial axis of the result (the original W axis). Together this is a
    90-degree rotation of every frame, presumably to match the expected display
    orientation. `rearrange` raises if the leading dimension is not 1.
    """
    return rearrange(x, "1 T H W C -> 1 T W H C").flip(-3)


with torch.no_grad():
    trues_decoded = _to_plot_orientation(ae.decode(trues_rollout))
    preds_decoded = _to_plot_orientation(ae.decode(preds_rollout))
    # NOTE(review): `[0]` assumes `model.rollout` returns a sequence whose first
    # element holds the predicted latents — confirm against ProcessorModel.rollout.
    preds_free_running_decoded = _to_plot_orientation(
        ae.decode(preds_free_running[0])
    )

In [None]:
# Teacher-forced predictions vs. ground truth, both decoded to physical space.
tf_decoded_plot_kwargs = {
    "true": trues_decoded,
    "pred": preds_decoded,
    "batch_idx": 0,
    "save_path": "teacher_forcing_decoded_prediction.mp4",
    "title": "Teacher forcing (decoded) prediction",
    "colorbar_mode": "row",
}
anim = plot_spatiotemporal_video(**tf_decoded_plot_kwargs)
HTML(anim.to_jshtml())

In [None]:
# Free-running (autoregressive) decoded rollout vs. decoded ground truth.
# Both sequences are cut to their common length along the time axis (dim 1) so
# `true` and `pred` always have the same number of frames, whether the rollout
# is longer or shorter than the stitched ground truth.
n_common_steps = min(trues_decoded.shape[1], preds_free_running_decoded.shape[1])
anim = plot_spatiotemporal_video(
    true=trues_decoded[:, :n_common_steps],
    pred=preds_free_running_decoded[:, :n_common_steps],
    batch_idx=0,
    save_path="free_running_decoded_prediction.mp4",
    title="Free running (decoded) prediction",
    colorbar_mode="row",
)
HTML(anim.to_jshtml())