## AutoCast end-to-end training and evaluation example

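This notebook demonstrates an end-to-end workflow:

- training an autoencoder
- training a flow-matching (or diffusion) processor in the autoencoder's latent space
- evaluation, including ensemble rollouts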


### Example dataset

We use the `advection_diffusion_multichannel` dataset as an example to illustrate training and evaluation of models. This dataset simulates the advection-diffusion equation in 2D; change `simulation_name` below to use a different dataset (e.g. `reaction_diffusion`).


In [None]:
import lightning as L

from autocast.metrics import MAE, MSE, RMSE

# Data / task configuration
THE_WELL = False
simulation_name = "advection_diffusion_multichannel"
n_steps_input = 1  # number of input time steps given to the model
n_steps_output = 4  # number of time steps predicted per step of the model
stride = 1
rollout_stride = 4  # passed as `stride` to the ensemble rollout further down

# Seed Python, NumPy and torch RNGs (including dataloader workers) so that
# training, sampling and ensemble rollouts are reproducible on Restart & Run All.
SEED = 42
L.seed_everything(SEED, workers=True)


### Read combined data into datamodules


In [None]:
from autocast.data.utils import get_datamodule

# Both datamodules read the same simulation with the same windowing; they only
# differ in `autoencoder_mode`.
datamodule_kwargs = dict(
    the_well=THE_WELL,
    simulation_name=simulation_name,
    n_steps_input=n_steps_input,
    n_steps_output=n_steps_output,
    stride=stride,
)
ae_datamodule = get_datamodule(**datamodule_kwargs, autoencoder_mode=True, num_workers=0)
datamodule = get_datamodule(**datamodule_kwargs, autoencoder_mode=False, num_workers=0)

### Set up logging


In [None]:
from autocast.logging import maybe_watch_model
from autocast.logging.wandb import create_notebook_logger

# Logging is switched off here (enabled=False); flip it to True to record the run.
run_name = f"04_e2e_{simulation_name}"
run_tags = ["notebook", simulation_name]
logger, watch = create_notebook_logger(
    project="autocast-notebooks",
    name=run_name,
    tags=run_tags,
    enabled=False,
)

In [None]:
batch = next(iter(datamodule.train_dataloader()))
# NOTE(review): input_fields is presumably (batch, time, height, width, channels),
# matching the "B, T, H, W, C" shape comments in the rollout section — confirm.
n_channels = batch.input_fields.shape[-1]
h, w = batch.input_fields.shape[2:4]

In [None]:
# Choose the number of latent channels so that (input elements) / (latent elements)
# equals `compression`, assuming the encoder maps each frame onto a
# latent_width x latent_height grid with n_latent channels.
# NOTE(review): the actual latent grid is set by the encoder's downsampling below;
# check the printed "Encoded shape" to confirm it really is 16x16. `latent_stride`
# is not used in this cell.
compression = 4
latent_width, latent_height, latent_stride = 16, 16, 2
n_latent = (w * h * n_channels // (latent_width * latent_height)) // compression
if n_latent < 1:
    # Integer division can silently reach 0, which would build an encoder with
    # zero output channels; fail early with an actionable message instead.
    raise ValueError(
        f"compression={compression} is too large for {n_channels}-channel inputs "
        f"with {w * h} grid points and a {latent_height}x{latent_width} latent grid "
        f"(n_latent would be {n_latent})"
    )

print(f"n_latent channels equals {n_latent} for a compression factor of {compression}")

In [None]:
# Build the autoencoder (trained in a later cell)
from autocast.decoders.dc import DCDecoder
from autocast.encoders.dc import DCEncoder
from autocast.models.autoencoder import AE

# Three hidden levels; the decoder mirrors the encoder in reverse order.
# A two-level variant ([32, 64] with blocks [2, 2]) gives one fewer factor-of-2
# downsampling.
ae_hid_channels = [32, 64, 128]
ae_hid_blocks = [2, 2, 2]
shared_kwargs = dict(
    kernel_size=3,
    stride=2,
    spatial=2,
    pixel_shuffle=False,
    periodic=False,
    dropout=None,
)

encoder = DCEncoder(
    in_channels=n_channels,
    out_channels=n_latent,
    hid_channels=ae_hid_channels,
    hid_blocks=ae_hid_blocks,
    **shared_kwargs,
)
decoder = DCDecoder(
    in_channels=n_latent,
    out_channels=n_channels,
    hid_channels=ae_hid_channels[::-1],
    hid_blocks=ae_hid_blocks[::-1],
    **shared_kwargs,
)

ae = AE(encoder=encoder, decoder=decoder, learning_rate=5e-4)

In [None]:
# Encode one autoencoder-mode training batch to inspect the latent shape.
ae_sample_batch = next(iter(ae_datamodule.train_dataloader()))
encoded, global_cond = ae.encoder.encode_with_cond(ae_sample_batch)


In [None]:
print(f"Encoded shape is: {tuple(encoded.shape)}")

In [None]:
import lightning as L

# "auto" lets Lightning pick an available accelerator (e.g. GPU/MPS) and fall back
# to CPU, so the notebook also runs on machines without Apple MPS.
device = "auto"
trainer = L.Trainer(max_epochs=2, accelerator=device, logger=logger)
trainer.fit(ae, ae_datamodule.train_dataloader(), ae_datamodule.val_dataloader())
trainer.save_checkpoint(f"./{simulation_name}_ae_model.ckpt")

In [None]:
from itertools import islice

import matplotlib.pyplot as plt
import torch

# Visualise a few autoencoder reconstructions from the test split on the CPU.
device = "cpu"
num_examples = 2

# Put the model on the same device as the (CPU) dataloader batches and switch to
# eval mode for inference; the previous train/eval mode is restored afterwards.
ae.to(device)
was_training = ae.training
ae.eval()
try:
    with torch.no_grad():
        for batch in islice(ae_datamodule.test_dataloader(), num_examples):
            inputs = batch.input_fields.to(device)
            outputs, latents = ae.forward_with_latent(batch)
            print("Input shape:", inputs.shape)
            print("Output shape:", outputs.shape)
            print("Latent shape:", latents.shape)
            # Show channel 1 when it exists, otherwise channel 0, so single-channel
            # datasets (e.g. "advection_diffusion") do not raise an IndexError.
            channel = min(1, inputs.shape[-1] - 1)
            input_frame = inputs[0, 0, :, :, channel].detach().cpu().numpy()
            output_frame = outputs[0, 0, :, :, channel].detach().cpu().numpy()
            latent_frame = latents[0, 0, :, :, 0].detach().cpu().numpy()
            fig, axs = plt.subplots(1, 4, figsize=(10, 4))
            panels = [
                (input_frame, "Input"),
                (output_frame, "Reconstruction"),
                (output_frame - input_frame, "Difference"),
                (latent_frame, "Latent"),
            ]
            for ax, (image, title) in zip(axs, panels):
                ax.imshow(image, cmap="viridis")
                ax.set_title(title)
            plt.tight_layout()
            plt.show()
            # Release the figure so repeated runs do not accumulate open figures.
            plt.close(fig)
finally:
    ae.train(was_training)


### Example shape and batch


In [None]:
first_train_sample = datamodule.train_dataset[0]
first_train_sample.input_fields.shape

In [None]:
train_loader = datamodule.train_dataloader()
batch = next(iter(train_loader))

batch.input_fields.shape

In [None]:
from azula.noise import VPSchedule

from autocast.metrics.deterministic import VRMSE
from autocast.models.encoder_processor_decoder import EncoderProcessorDecoder
from autocast.nn.base import TemporalBackboneBase
from autocast.nn.unet import TemporalUNetBackbone
from autocast.nn.vit import TemporalViTBackbone
from autocast.processors.diffusion import DiffusionProcessor
from autocast.processors.flow_matching import FlowMatchingProcessor

# Sample one training batch to infer channel counts for the backbone: the number
# of physical channels, and the width of the global conditioning vector (None
# when the encoder provides no global conditioning).
batch = next(iter(datamodule.train_dataloader()))
n_channels = batch.input_fields.shape[-1]
example_global_cond = encoder.encode_cond(batch)
if example_global_cond is None:
    global_cond_channels = None
else:
    global_cond_channels = example_global_cond.shape[-1]

# Construct backbone

def get_backbone(
    backbone_name: str, n_channels_backbone: "int | None" = None
) -> TemporalBackboneBase:
    """Create the temporal backbone used by the processor.

    Args:
        backbone_name: Either ``"unet"`` or ``"vit"``.
        n_channels_backbone: Number of channels of the fields the backbone reads,
            conditions on and predicts. Defaults to ``n_latent``, because the
            processor in this notebook is built with ``n_channels_out=n_latent``
            and trained with ``train_in_latent_space=True``. It therefore sees
            encoded fields, not the ``n_channels`` physical fields. Pass
            ``n_channels`` when training in physical space.
            NOTE(review): assumes the conditioning inputs are encoded fields too;
            confirm against the processor implementation.

    Returns:
        The configured ``TemporalUNetBackbone`` or ``TemporalViTBackbone``.

    Raises:
        ValueError: If ``backbone_name`` is not recognised.
    """
    channels = n_latent if n_channels_backbone is None else n_channels_backbone

    if backbone_name == "unet":
        return TemporalUNetBackbone(
            in_channels=channels,
            out_channels=channels,
            cond_channels=channels,
            n_steps_output=n_steps_output,
            n_steps_input=n_steps_input,
            global_cond_channels=global_cond_channels,
            mod_features=200,
            hid_channels=(32, 64, 128),
            hid_blocks=(2, 2, 2),
            spatial=2,
            periodic=False,
        )
    if backbone_name == "vit":
        return TemporalViTBackbone(
            in_channels=channels,
            out_channels=channels,
            cond_channels=channels,
            n_steps_output=n_steps_output,
            n_steps_input=n_steps_input,
            mod_features=256,
            global_cond_channels=global_cond_channels,
            # Larger alternative: hid_channels=768, hid_blocks=12, attention_heads=12.
            hid_channels=512,
            hid_blocks=8,
            # Alternative: temporal_method="attention".
            temporal_method="none",
            attention_heads=8,
            spatial=2,
            patch_size=1,
            dropout=0.05,
            ffn_factor=4,
            checkpointing=True,
        )
    raise ValueError(f"Unknown backbone name: {backbone_name}")


backbone_name = "vit"  # set to "unet" or "vit"
backbone = get_backbone(backbone_name)


# Construct processor: "flow_matching" or "diffusion"
processor_name = "flow_matching"

if processor_name == "flow_matching":
    processor = FlowMatchingProcessor(
        backbone=backbone,
        n_steps_output=n_steps_output,
        n_channels_out=n_latent,
        flow_ode_steps=20,
    )
elif processor_name == "diffusion":
    processor = DiffusionProcessor(
        backbone=backbone,
        schedule=VPSchedule(),
        n_steps_output=n_steps_output,
        n_channels_out=n_latent,
    )
else:
    # Fail loudly instead of silently falling back to diffusion on a typo.
    raise ValueError(f"Unknown processor name: {processor_name}")

# The processor is trained on encoded (latent) fields; validation and test each
# get their own metric instances.
model = EncoderProcessorDecoder(
    encoder_decoder=ae,
    processor=processor,
    train_in_latent_space=True,
    learning_rate=5e-4,
    val_metrics=[VRMSE(), MSE(), MAE(), RMSE()],
    test_metrics=[VRMSE(), MSE(), MAE(), RMSE()],
)
maybe_watch_model(logger, model, watch)


### Run trainer


In [None]:
import lightning as L

# "auto" lets Lightning pick an available accelerator (e.g. GPU/MPS) and fall back
# to CPU, so the notebook also runs on machines without Apple MPS.
device = "auto"
trainer = L.Trainer(max_epochs=5, accelerator=device, logger=logger)
trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())
trainer.save_checkpoint(f"./{simulation_name}_{processor_name}_model.ckpt")

### Run the evaluation


In [None]:
test_results = trainer.test(model, dataloaders=datamodule.test_dataloader())
test_results

### Example rollout


In [None]:
# A single element is the full trajectory
datamodule = get_datamodule(
    the_well=THE_WELL,
    simulation_name=simulation_name,
    n_steps_input=n_steps_input,
    n_steps_output=n_steps_output,
    stride=stride,
    autoencoder_mode=False,
    batch_size=2,
    num_workers=0,
)

batch = next(iter(datamodule.rollout_test_dataloader()))

In [None]:
# input_fields: the first n_steps_input steps of each trajectory.
# output_fields: the remaining steps. NOTE(review): for the rollout dataloader this
# is presumably the rest of the trajectory rather than only n_steps_output steps;
# check the printed shape.
for fields in (batch.input_fields, batch.output_fields):
    print(fields.shape)

In [None]:
from autocast.models.encoder_processor_decoder_ensemble import (
    EncoderProcessorDecoderEnsemble,
)

# Reuse the trained encoder/decoder and processor as a 5-member ensemble. No
# metrics are attached because this model is only used for rollouts below.
# NOTE(review): train_in_latent_space=False differs from the trained model (True);
# presumably irrelevant for inference-only use — confirm.
ensemble_kwargs = dict(
    encoder_decoder=model.encoder_decoder,
    processor=model.processor,
    train_in_latent_space=False,
    learning_rate=5e-4,
    test_metrics=[],
    val_metrics=[],
    n_members=5,
    batch_size=2,
)
ensemble_model = EncoderProcessorDecoderEnsemble(**ensemble_kwargs)


In [None]:
# Run the ensemble rollout on the trajectories in `batch` (the rollout datamodule
# uses batch_size=2, so this is two trajectories, not one). `max_rollout_steps`
# caps the rollout length, and `stride` is set from `rollout_stride`.
# NOTE(review): free_running_only=True presumably feeds predictions back as inputs
# (no teacher forcing) — confirm against EncoderProcessorDecoderEnsemble.rollout.
preds, trues = ensemble_model.rollout(
    batch, stride=rollout_stride, max_rollout_steps=80, free_running_only=True,
    n_members=5,
)

print(preds.shape) # B, T, H, W, C, M (M = ensemble members)
assert trues is not None
print(trues.shape) # NOTE(review): later cells assert trues.shape == preds.shape and take trues[..., 0], so this presumably also has a trailing M axis — confirm


In [None]:
from autocast.metrics import MSE

# The metric is applied to the full ensemble prediction, so both tensors must match.
assert trues is not None
assert preds.shape == trues.shape
mse_error = MSE()(preds, trues)
print("MSE overall is a single scalar:", mse_error)

In [None]:
from IPython.display import HTML

from autocast.utils import plot_spatiotemporal_video

# Channel labels for the known simulations; any other simulation gets None.
CHANNEL_NAMES_BY_SIMULATION = {
    "advection_diffusion_multichannel": [
        "vorticity",
        "velocity_x",
        "velocity_y",
        "streamfunction",
    ],
    "advection_diffusion": ["vorticity"],
    "reaction_diffusion": ["U", "V"],
}
channel_names = CHANNEL_NAMES_BY_SIMULATION.get(simulation_name)
batch_idx = 0

# Summarise the ensemble (last axis) by its mean and standard deviation.
ensemble_mean = preds.mean(dim=-1)
ensemble_std = preds.std(dim=-1)
anim = plot_spatiotemporal_video(
    pred=ensemble_mean,
    true=trues[..., 0],
    pred_uq=ensemble_std,
    batch_idx=batch_idx,
    save_path=f"{simulation_name}_{batch_idx:02d}.mp4",
    colorbar_mode="column",
    channel_names=channel_names,
    pred_uq_label="Ensemble Std. Dev.",
    colorbar_mode_uq="row",
)

In [None]:
# Render the animation inline in the notebook as an HTML/JavaScript player.
animation_html = anim.to_jshtml()
HTML(animation_html)