## AutoCast Processor Evaluation

This notebook evaluates a pre-trained processor model on the MiniWell dataset.
It retrieves the run configuration from Weights & Biases and loads the model weights from a checkpoint in a specified run directory.


In [None]:
import os

import matplotlib.pyplot as plt
import torch
from hydra.utils import instantiate
from IPython.display import HTML
from omegaconf import OmegaConf

from autocast.models.processor import ProcessorModel
from autocast.utils.plots import plot_spatiotemporal_video

# Select the compute device: prefer CUDA, then Apple MPS, then fall back to CPU,
# so the notebook also runs on machines without an NVIDIA GPU.
# To force a specific backend, reassign `device` right after this block
# (e.g. device = "cpu").
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [None]:
from pathlib import Path

import wandb
from hydra import compose, initialize_config_dir

# The run finished before its config was saved, so fetch the run's config
# through the wandb API instead.
api = wandb.Api()
run = api.run("turing-core/autocast/runs/j7x1q8xq")

# Copy the run config into a plain dict, wrap it in OmegaConf, and keep the
# sub-config stored under the "hydra" key.
resolved_config = dict(run.config)
wandb_cfg = OmegaConf.create(resolved_config)
cfg = wandb_cfg["hydra"]

# Override the data path with a location relative to this notebook.
cfg.data.data_path = (
    "../datasets/rayleigh_benard/1e3z5x2c_rayleigh_benard_dcae_f32c64_large/"
    "cache/rayleigh_benard"
)

# Dump the final config for inspection.
print(OmegaConf.to_yaml(cfg))

In [None]:
# Prepare datamodule and configure processor dimensions

from autocast.train.processor import (
    configure_processor_dimensions,
    prepare_encoded_datamodule,
)

# prepare_encoded_datamodule(cfg) returns nine values; unpack them by position.
prepared = prepare_encoded_datamodule(cfg)
(
    datamodule,
    in_channel_count,
    out_channel_count,
    global_cond_channels,
    inferred_n_steps_input,
    inferred_n_steps_output,
    input_shape,
    output_shape,
    _example_batch,
) = prepared

# Hand the inferred channel and step counts to configure_processor_dimensions.
# Its return value is not used here (presumably it updates cfg in place;
# TODO confirm).
processor_dims = {
    "in_channel_count": in_channel_count,
    "out_channel_count": out_channel_count,
    "global_cond_channels": global_cond_channels,
    "n_steps_input": inferred_n_steps_input,
    "n_steps_output": inferred_n_steps_output,
}
configure_processor_dimensions(cfg, **processor_dims)

In [None]:
# Instantiate latent datamodule and setup

# NOTE: this rebinds `datamodule`, replacing the instance returned earlier by
# prepare_encoded_datamodule.
datamodule = instantiate(cfg.data)
# Only the "test" stage is set up here.
datamodule.setup("test")

In [None]:
# Instantiate Processor
# Built from the `model.processor` node of the config.
processor = instantiate(cfg.model.processor)

# Construct ProcessorModel around the processor, using the configured
# learning rate.
model = ProcessorModel(processor=processor, learning_rate=cfg.model.learning_rate)

In [None]:
# Run directory that holds the checkpoint to evaluate.
run_path = "../outputs/processor/20260117_040704/autocast/j7x1q8xq/"

# Checkpoint file, given relative to the run directory.
# Alternatives kept for reference:
# config_path = os.path.join(run_path, "resolved_processor_config.yaml")
# ckpt_path = os.path.join(run_path, "processor.ckpt")
checkpoint_relpath = "checkpoints/step-step=60000.ckpt"
ckpt_path = os.path.join(run_path, checkpoint_relpath)

In [None]:
# Load checkpoint
# weights_only=True restricts unpickling to tensors and plain containers;
# map_location remaps the stored tensors onto `device`.
checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
# The model weights are stored under the "state_dict" key of the checkpoint.
model.load_state_dict(checkpoint["state_dict"])
# Switch the model to evaluation mode.
model.eval()
print("Model loaded successfully")

In [None]:
# Locate the AutoEncoder run (used to decode predictions) and load its config

ae_path = "../datasets/rayleigh_benard/1e3z5x2c_rayleigh_benard_dcae_f32c64_large"
# Files inside the autoencoder run directory. `ae_ckpt_path` is defined for
# reference; nothing in this cell reads it.
ae_config_path = os.path.join(ae_path, "config.yaml")
ae_ckpt_path = os.path.join(ae_path, "state.pth")

print(f"Loading AutoEncoder from: {ae_path}")
ae_cfg = OmegaConf.load(ae_config_path)
ae_section = ae_cfg.ae

# Pass the `ae` section as a plain, resolved container: handing OmegaConf nodes
# to the constructors conflicts with beartype for args (e.g. attention_heads).
ae_config_dict = OmegaConf.to_container(ae_section, resolve=True)

# Exception: get_autoencoder specifically types 'loss' as DictConfig, so put
# the original OmegaConf node back for that key.
if "loss" in ae_section:
    ae_config_dict["loss"] = ae_section.loss  # type: ignore  # noqa: PGH003

In [None]:
# Get dataset mean and std for normalization

# Both statistics are read from the autoencoder config (`dataset.stats`) and
# converted to tensors.
mean = torch.tensor(ae_cfg.dataset.stats.mean)
std = torch.tensor(ae_cfg.dataset.stats.std)
# Trailing expression so the notebook displays both tensors.
mean, std

In [None]:
# Initialize Encoder and Decoder and load weights from ae_path

from autocast.external.lola.wrapped_decoder import WrappedDecoder
from autocast.external.lola.wrapped_encoder import WrappedEncoder
from autocast.models.autoencoder import AE

# Encoder and decoder take the same arguments: the device, the autoencoder run
# directory, the dataset mean/std and every entry of the autoencoder config.
wrapper_kwargs = dict(
    device=device,
    runpath=ae_path,
    mean=mean,
    std=std,
    **ae_config_dict,  # type: ignore  # noqa: PGH003
)
encoder = WrappedEncoder(**wrapper_kwargs)
decoder = WrappedDecoder(**wrapper_kwargs)

# Combine both halves into one autoencoder and put it in evaluation mode; the
# result is bound to `_` so the notebook does not echo the module.
ae = AE(encoder=encoder, decoder=decoder)
_ = ae.eval()

In [None]:
# Ambient dataloader

# Compose the data config from ../configs/, resolved against the current
# working directory.
config_dir = str(Path.cwd() / "../configs/")
with initialize_config_dir(version_base=None, config_dir=config_dir):
    composed = compose(
        config_name="data/the_well",
        overrides=["data.well_dataset_name=rayleigh_benard"],
    )
    data_cfg = composed["data"]
    data_cfg.batch_size = 2
    # TODO: for the moment handle the normalization at encoder/decoder level
    # data_cfg.use_normalization = True

# Build the datamodule from the composed config; only the "test" stage is set up.
ambient_datamodule = instantiate(data_cfg)
ambient_datamodule.setup("test")

In [None]:
# Construct EncoderProcessorDecoder

from autocast.models.encoder_processor_decoder import EncoderProcessorDecoder

# Combine the autoencoder with the processor held by `model` (the ProcessorModel
# whose weights were loaded from the checkpoint).
epd = EncoderProcessorDecoder(encoder_decoder=ae, processor=model.processor)

In [None]:
# Get a batch

# Take the first batch of the rollout test dataloader. The built-ins
# next()/iter() replace direct calls to the __iter__/__next__ dunder methods.
batch = next(iter(ambient_datamodule.rollout_test_dataloader()))

In [None]:
# Check autoencoder reconstruction

# Inference only: disable autograd so no graph is kept for the reconstruction.
with torch.no_grad():
    output_recon = ae(batch)
diff = torch.abs(batch.input_fields - output_recon)

# One panel per tensor: the input, its reconstruction and their absolute
# difference.
panels = (
    ("input", batch.input_fields),
    ("recon", output_recon),
    ("diff", diff),
)

fig, axs = plt.subplots(1, 3)
for ax, (title, field) in zip(axs, panels):
    # Plot the 2D slice at index 0 of the first, second and last axes
    # (presumably batch, time and channel; TODO confirm the axis layout).
    image = ax.imshow(field[0, 0, :, :, 0].detach().cpu().numpy())
    ax.set_title(title)
    fig.colorbar(image, ax=ax)
plt.tight_layout()
plt.show()

In [None]:
# Run rollout on the batch of trajectories

rollout_stride = 4
# Evaluation only: disable autograd so no graph is kept for the rollout
# outputs.
with torch.no_grad():
    preds, trues = epd.rollout(
        batch,
        stride=rollout_stride,
        max_rollout_steps=25,
        free_running_only=True,
    )
print(preds.shape)
# The cells below need ground truth, so fail early if rollout returned none.
assert trues is not None
print(trues.shape)

In [None]:
# Test metrics computed

from autocast.metrics import MSE

# The metric compares predictions against ground truth, so both must exist
# and have identical shapes.
assert trues is not None
assert preds.shape == trues.shape

# Evaluate the metric, then pull the result out as a plain Python number.
mse = MSE()
mse_tensor = mse(preds, trues)
mse_error = mse_tensor.detach().cpu().item()
print(f"MSE overall as a single scalar: {mse_error:.3f}")

In [None]:
# Construct spatiotemporal video

batch_idx = 0
# NOTE(review): the metadata is read from `train_dataset` although the
# datamodule was only set up for the "test" stage; confirm it is populated.
metadata = ambient_datamodule.train_dataset.well_metadata
simulation_name = metadata.dataset_name

# Flatten the groups of field names into one flat list of channel names.
channel_names = []
for names in metadata.field_names.values():
    channel_names.extend(names)

# The output file is named after the dataset and the zero-padded batch index.
anim = plot_spatiotemporal_video(
    pred=preds,
    true=trues,
    batch_idx=batch_idx,
    save_path=f"{simulation_name}_{batch_idx:02d}.mp4",
    colorbar_mode="column",
    channel_names=channel_names,
)

In [None]:
# Plot spatiotemporal video

# Wrap the animation's JS/HTML output in an IPython HTML object; as the last
# expression of the cell it becomes the cell's displayed output.
HTML(anim.to_jshtml())