## AutoCast encoder-processor-decoder model API Exploration

This notebook aims to explore the end-to-end API.


### Example dataaset

We use the `AdvectionDiffusion` dataset as an example dataset to illustrate training and evaluation of models. This dataset simulates the advection-diffusion equation in 2D.


In [None]:
import pickle
from pathlib import Path

from autoemulate.simulations.advection_diffusion import AdvectionDiffusion
from autoemulate.simulations.reaction_diffusion import ReactionDiffusion

from autocast.data.advection_diffusion import (
    AdvectionDiffusion as AdvectionDiffusionMultichannel,
)
from autocast.data.datamodule import SpatioTemporalDataModule, TheWellDataModule
from autocast.metrics.spatiotemporal import MAE, MSE, RMSE

THE_WELL = False
n_steps_input = 1
n_steps_output = 4
stride = 1
rollout_stride = n_steps_output


In [None]:
#n_steps_input = 4
#n_steps_output = 1
#original TheWell paper


### Read combined data into datamodule

In [2]:

if not THE_WELL:
    simulation_name = "reaction_diffusion"
    # simulation_name = "advection_diffusion"
    # simulation_name = "advection_diffusion_multichannel"

    if simulation_name == "advection_diffusion_multichannel":
        # Override to use multichannel version
        Sim = AdvectionDiffusionMultichannel
    if simulation_name == "reaction_diffusion":
        Sim = ReactionDiffusion
    if simulation_name == "advection_diffusion":
        Sim = AdvectionDiffusion

    sim = Sim(return_timeseries=True, log_level="error")

    def generate_split(
        simulator, n_train: int = 200, n_valid: int = 20, n_test: int = 20
    ):
        """Generate training, validation, and test splits from the simulator."""
        train = simulator.forward_samples_spatiotemporal(n_train)
        valid = simulator.forward_samples_spatiotemporal(n_valid)
        test = simulator.forward_samples_spatiotemporal(n_test)
        return {"train": train, "valid": valid, "test": test}

    # Cache file path
    cache_file = Path(f"{simulation_name}_cache.pkl")

    # Load from cache if it exists, otherwise generate and save
    if cache_file.exists():
        print(f"Loading cached simulation data from {cache_file}")
        with open(cache_file, "rb") as f:
            combined_data = pickle.load(f)
            for key in ["data", "constant_scalars", "constant_fields"]:
                combined_data["test"][key] = (
                    combined_data["test"][key][:5]
                    if combined_data["test"][key] is not None
                    else None
                )
    else:
        print("Generating simulation data...")
        combined_data = generate_split(sim)
        print(f"Saving simulation data to {cache_file}")
        with open(cache_file, "wb") as f:
            pickle.dump(combined_data, f)

    datamodule = SpatioTemporalDataModule(
        data=combined_data,
        data_path = None,
        #data_path="../datasets/reaction_diffusion",
        n_steps_input=n_steps_input,
        n_steps_output=n_steps_output,
        stride=n_steps_output,
        batch_size=16,
    )
else:
    simulation_name = "turbulent_radiative_layer_2D"
    datamodule = TheWellDataModule(
        well_base_path="../../autocast/datasets/",
        well_dataset_name=simulation_name,
        n_steps_input=n_steps_input,
        n_steps_output=n_steps_output,
        min_dt_stride=1,
        use_normalization=True,
    )


Loading cached simulation data from reaction_diffusion_cache.pkl


In [3]:
combined_data["train"]["data"].shape

torch.Size([8, 100, 32, 32, 2])

### Set-up logging

In [4]:

from autocast.logging import create_wandb_logger, maybe_watch_model
from autocast.logging.wandb import create_notebook_logger

logger, watch = create_notebook_logger(
    project="autocast-notebooks",
    name=f"00_01_exploration_{simulation_name}",
    tags=["notebook", simulation_name],
)

### Example shape and batch


In [5]:
datamodule.train_dataset[0].input_fields.shape

torch.Size([1, 32, 32, 2])

In [6]:
batch = next(iter(datamodule.train_dataloader()))

batch.input_fields.shape

torch.Size([16, 1, 32, 32, 2])

In [11]:
from azula.noise import VPSchedule

from autocast.decoders.identity import IdentityDecoder
from autocast.encoders.identity import IdentityEncoder
from autocast.models.encoder_decoder import EncoderDecoder
from autocast.models.encoder_processor_decoder import EncoderProcessorDecoder
from autocast.nn.unet import TemporalUNetBackbone
from autocast.processors.flow_matching import FlowMatchingProcessor
from autocast.processors.fno import FNOProcessor
from autocast.encoders.permute_concat import PermuteConcat
from autocast.decoders.channels_last import ChannelsLast

batch = next(iter(datamodule.train_dataloader()))
n_channels = batch.input_fields.shape[-1]

processor_name = "flow_matching"  # set to "diffusion" to compare
# processor_name = "diffusion"  # set to "flow_matching" to compare

backbone = TemporalUNetBackbone(
    in_channels=n_channels * n_steps_output,
    out_channels=n_channels * n_steps_output,
    cond_channels=n_channels * n_steps_input,
    mod_features=200,
    hid_channels=(32, 64, 128),
    hid_blocks=(2, 2, 2),
    spatial=2,
    periodic=False,
)

if processor_name == "flow_matching":
    processor = FlowMatchingProcessor(
        backbone=backbone,
        schedule=VPSchedule(),  # accepted for API parity, not used internally
        n_steps_output=n_steps_output,
        n_channels_out=n_channels,
        stride=stride,
        flow_ode_steps=4,
    )
else:
    from autocast.processors.diffusion import DiffusionProcessor

    processor = DiffusionProcessor(
        backbone=backbone,
        schedule=VPSchedule(),
        n_steps_output=n_steps_output,
        n_channels_out=n_channels,
    )

encoder = PermuteConcat()
decoder = ChannelsLast(output_channels=n_channels, time_steps=n_steps_output)

processor = FNOProcessor(
    in_channels=n_channels * n_steps_input,
    out_channels=n_channels * n_steps_output,
    n_modes = (16,16)
)


model = EncoderProcessorDecoder(
    encoder_decoder=EncoderDecoder(encoder=encoder, decoder=decoder),
    processor=processor,
    train_processor_only=True,
    learning_rate=1e-4,
    test_metrics = [MSE(), MAE(), RMSE()],
    strie = stride
)
maybe_watch_model(logger, model, watch)


In [12]:
model(batch).shape

torch.Size([16, 4, 32, 32, 2])

### Run trainer


In [13]:
#logger.logging.wandb.enabled=False

In [14]:
import lightning as L

device = "mps"  # "cpu"
# device = "cpu"
trainer = L.Trainer(
    max_epochs=4, accelerator=device, log_every_n_steps=10, logger=logger
)
trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())
trainer.save_checkpoint(f"./{simulation_name}_{processor_name}_model.ckpt")

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores

  | Name            | Type             | Params | Mode 
-------------------------------------------------------------
0 | encoder_decoder | EncoderDecoder   | 0      | eval 
1 | processor       | FNOProcessor     | 2.4 M  | train
2 | val_metrics     | MetricCollection | 0      | train
3 | test_metrics    | MetricCollection | 0      | train
-------------------------------------------------------------
2.4 M     Trainable params
0         Non-trainable params
2.4 M     Total params
9.645     Total estimated model params size (MB)
59        Modules in train mode
3         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=4` reached.


### Run the evaluation


In [15]:
trainer.test(model, datamodule.test_dataloader())

Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.3945525586605072
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.3945525586605072}]

### Example rollout


In [16]:
# A single element is the full trajectory
batch = next(iter(datamodule.rollout_test_dataloader()))

In [17]:
# First n_steps_input are inputs
print(batch.input_fields.shape)
# Remaining n_steps_output are outputs
print(batch.output_fields.shape)

torch.Size([2, 1, 32, 32, 2])
torch.Size([2, 99, 32, 32, 2])


In [18]:
# Run rollout on one trajectory
model.max_rollout_steps = 20
preds, trues = model.rollout(batch, stride=rollout_stride, free_running_only=True)

print(preds.shape)
assert trues is not None
print(trues.shape)


torch.Size([2, 40, 32, 32, 2])
torch.Size([2, 40, 32, 32, 2])


In [19]:
from autocast.metrics.spatiotemporal import MSE

assert trues is not None
assert preds.shape == trues.shape
mse = MSE()
mse_error_spatial = mse(preds, trues)
mse_error = mse(preds, trues)
print("MSE spatial has shape (B,T,C):", mse_error_spatial.shape)
print("MSE overall is a single scalar:", mse_error.shape)

MSE spatial has shape (B,T,C): torch.Size([])
MSE overall is a single scalar: torch.Size([])


In [20]:
from IPython.display import HTML

from autocast.utils import plot_spatiotemporal_video

batch_idx = 0
if simulation_name == "advection_diffusion_multichannel":
    channel_names = ["vorticity", "velocity_x", "velocity_y", "streamfunction"]
elif simulation_name == "advection_diffusion":
    channel_names = ["vorticity"]
elif simulation_name == "reaction_diffusion":
    channel_names = ["U", "V"]
else:
    channel_names = None

anim = plot_spatiotemporal_video(
    pred=preds,
    true=trues,
    batch_idx=batch_idx,
    save_path=f"{simulation_name}_{batch_idx:02d}.mp4",
    colorbar_mode="column",
    channel_names=channel_names,
)
HTML(anim.to_jshtml())

Video saved to reaction_diffusion_00.mp4
