## AutoCast encoder-processor-decoder model API Exploration

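This notebook explores the end-to-end AutoCast API: loading data, building an encoder-processor-decoder model, training and testing it, and running an autoregressive rollout.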


### Example dataset

We use a simulated 2D PDE dataset to illustrate training and evaluation of models. The dataset is selected with `simulation_name` in the next cell: `reaction_diffusion` (used here) or `advection_diffusion`, which simulates the advection-diffusion equation.


In [None]:
import pickle
from pathlib import Path

from autoemulate.simulations.advection_diffusion import AdvectionDiffusion
from autoemulate.simulations.reaction_diffusion import ReactionDiffusion

from autocast.data.advection_diffusion import (
    AdvectionDiffusion as AdvectionDiffusionMultichannel,
)
from autocast.data.datamodule import SpatioTemporalDataModule, TheWellDataModule
from autocast.metrics.spatiotemporal import MAE, MSE, RMSE

# Experiment configuration shared by the cells below.
simulation_name: str = "reaction_diffusion"  # which simulated dataset to load
THE_WELL: bool = False  # forwarded to get_datamodule(the_well=...)
n_steps_input: int = 4  # number of input time steps per sample
n_steps_output: int = 1  # number of predicted time steps per sample
stride: int = 1
# Stride between rollout predictions; equal to the number of output steps.
rollout_stride: int = n_steps_output

### Load the data into a datamodule

In [None]:
from autocast.data.utils import get_datamodule

# Collect the datamodule settings in one place: data source, per-sample
# input/output step counts and stride, and the split sizes.
datamodule_kwargs = dict(
    the_well=THE_WELL,
    simulation_name=simulation_name,
    n_steps_input=n_steps_input,
    n_steps_output=n_steps_output,
    stride=stride,
    n_train=50,
    n_valid=10,
    n_test=10,
)
datamodule = get_datamodule(**datamodule_kwargs)

### Set up logging

In [None]:

from autocast.logging import create_wandb_logger, maybe_watch_model
from autocast.logging.wandb import create_notebook_logger

# Create the notebook logger; `watch` is handed to maybe_watch_model once the
# model has been built.
run_name = f"00_01_exploration_{simulation_name}"
run_tags = ["notebook", simulation_name]
logger, watch = create_notebook_logger(
    project="autocast-notebooks",
    name=run_name,
    tags=run_tags,
)

### Example shape and batch


In [None]:
# Take the first training batch to inspect tensor shapes.
train_loader = datamodule.train_dataloader()
batch = next(iter(train_loader))



In [None]:
# Size of the last axis of the constant scalars; used below to size the
# AViT processor's input channels.
*_, n_constant_scalars = batch.constant_scalars.shape

In [None]:
from azula.noise import VPSchedule

from autocast.decoders.identity import IdentityDecoder
from autocast.encoders.identity import IdentityEncoder
from autocast.models.encoder_decoder import EncoderDecoder
from autocast.models.encoder_processor_decoder import EncoderProcessorDecoder
from autocast.nn.unet import TemporalUNetBackbone
from autocast.processors.flow_matching import FlowMatchingProcessor
from autocast.processors.fno import FNOProcessor
from autocast.processors.vit import AViTProcessor
from autocast.encoders.permute_concat import PermuteConcat
from autocast.decoders.channels_last import ChannelsLast

batch = next(iter(datamodule.train_dataloader()))
n_channels = batch.input_fields.shape[-1]

# Processor to train: "avit", "flow_matching" or "diffusion".
# The name is also embedded in the checkpoint filename after training, so it
# must describe the processor that is actually built below.
processor_name = "avit"

if processor_name == "avit":
    # Input width includes the constant scalars, which the encoder
    # (PermuteConcat(with_constants=True)) presumably concatenates onto the
    # field channels for every input step.
    processor = AViTProcessor(
        in_channels=(n_channels + n_constant_scalars) * n_steps_input,
        out_channels=n_channels * n_steps_output,
        # NOTE(review): hard-coded; presumably must match the spatial size of
        # batch.input_fields — confirm for other datasets.
        spatial_resolution=(32, 32),
        hidden_dim=128,
        num_heads=8,
        n_layers=8,
        groups=8,
    )
elif processor_name in ("flow_matching", "diffusion"):
    # NOTE(review): cond_channels excludes n_constant_scalars even though the
    # encoder below is built with with_constants=True — confirm this is intended.
    backbone = TemporalUNetBackbone(
        in_channels=n_channels * n_steps_output,
        out_channels=n_channels * n_steps_output,
        cond_channels=n_channels * n_steps_input,
        mod_features=200,
        hid_channels=(32, 64, 128),
        hid_blocks=(2, 2, 2),
        spatial=2,
        periodic=False,
    )
    if processor_name == "flow_matching":
        processor = FlowMatchingProcessor(
            backbone=backbone,
            schedule=VPSchedule(),  # accepted for API parity, not used internally
            n_steps_output=n_steps_output,
            n_channels_out=n_channels,
            stride=stride,
            flow_ode_steps=4,
        )
    else:
        from autocast.processors.diffusion import DiffusionProcessor

        processor = DiffusionProcessor(
            backbone=backbone,
            schedule=VPSchedule(),
            n_steps_output=n_steps_output,
            n_channels_out=n_channels,
        )
else:
    raise ValueError(f"Unknown processor_name: {processor_name!r}")

encoder = PermuteConcat(with_constants=True)
decoder = ChannelsLast(output_channels=n_channels, time_steps=n_steps_output)

model = EncoderProcessorDecoder(
    encoder_decoder=EncoderDecoder(encoder=encoder, decoder=decoder),
    processor=processor,
    train_in_latent_space=False,
    learning_rate=5e-4,
    test_metrics=[MSE(), MAE(), RMSE()],
    stride=stride,  # was misspelled `strie`, so the stride was never passed
    loss_func=processor.loss_func,
)
maybe_watch_model(logger, model, watch)


In [None]:
# Shape of the encoder output for the example batch.
encoded_batch = encoder.encode(batch)
encoded_batch.shape

In [None]:
# Shape of the full encoder-processor-decoder output for the example batch.
model_output = model(batch)
model_output.shape

### Run trainer


In [None]:
#logger.logging.wandb.enabled=False

In [None]:
import lightning as L
import torch

# Prefer Apple-silicon MPS when it is available, otherwise fall back to CPU so
# the notebook also runs on machines without MPS support.
device = "mps" if torch.backends.mps.is_available() else "cpu"
trainer = L.Trainer(
    max_epochs=50, accelerator=device, log_every_n_steps=10, logger=logger
)
trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())
trainer.save_checkpoint(f"./{simulation_name}_{processor_name}_model.ckpt")

### Run the evaluation


In [None]:
# Evaluate the trained model with its configured test metrics.
trainer.test(model, dataloaders=datamodule.test_dataloader())

### Example rollout


In [None]:
# Each element from the rollout test loader holds a full trajectory.
rollout_loader = datamodule.rollout_test_dataloader()
batch = next(iter(rollout_loader))

In [None]:
# input_fields: the first n_steps_input time steps (model inputs);
# output_fields: the remaining time steps (targets).
for field_name in ("input_fields", "output_fields"):
    print(getattr(batch, field_name).shape)

In [None]:
# Run a free-running autoregressive rollout on one trajectory.
# A single horizon value keeps the model attribute and the explicit argument
# in agreement (they were previously 100 and 90 respectively).
max_rollout_steps = 90
model.max_rollout_steps = max_rollout_steps
preds, trues = model.rollout(
    batch,
    stride=rollout_stride,  # configured in the first cell (= n_steps_output)
    free_running_only=True,
    max_rollout_steps=max_rollout_steps,
)

print(preds.shape)
assert trues is not None
print(trues.shape)


In [None]:
from autocast.metrics.spatiotemporal import MSE

assert trues is not None
assert preds.shape == trues.shape
mse = MSE()
# Per-entry errors as returned by the metric.
# NOTE(review): the (B, T, C) layout is what the print below states — confirm
# against the MSE metric implementation.
mse_error_spatial = mse(preds, trues)
# Average the per-entry errors into one overall scalar. Calling mse(preds,
# trues) a second time would only return the same per-entry tensor again.
mse_error = mse_error_spatial.mean()
print("MSE spatial has shape (B,T,C):", mse_error_spatial.shape)
print("MSE overall is a single scalar:", mse_error.shape)

In [None]:
from IPython.display import HTML

from autocast.utils import plot_spatiotemporal_video

# Channel labels for the known simulations; any other simulation gets None.
CHANNEL_NAMES_BY_SIMULATION = {
    "advection_diffusion_multichannel": [
        "vorticity",
        "velocity_x",
        "velocity_y",
        "streamfunction",
    ],
    "advection_diffusion": ["vorticity"],
    "reaction_diffusion": ["U", "V"],
}

batch_idx = 0
channel_names = CHANNEL_NAMES_BY_SIMULATION.get(simulation_name)
video_path = f"{simulation_name}_{batch_idx:02d}.mp4"

# Render predicted vs. true fields, save them to video_path and show them inline.
anim = plot_spatiotemporal_video(
    pred=preds,
    true=trues,
    batch_idx=batch_idx,
    save_path=video_path,
    colorbar_mode="column",
    channel_names=channel_names,
)
HTML(anim.to_jshtml())

In [None]:
import matplotlib.pyplot as plt

# Compare prediction and ground truth over the rollout for batch element 0,
# taking index 0 on the three trailing axes (presumably one grid point and
# channel 0 — confirm the tensor layout).
# .cpu() makes .numpy() safe if the tensors live on an accelerator device.
plt.figure()
plt.plot(preds[0, :, 0, 0, 0].detach().cpu().numpy(), label="prediction")
plt.plot(trues[0, :, 0, 0, 0].detach().cpu().numpy(), label="ground truth")
plt.xlabel("rollout step")
plt.legend()