## AutoCast encoder-processor-decoder model API Exploration

This notebook aims to explore the end-to-end API.


### Example dataset

We use a simulated dataset as an example to illustrate training and evaluation of models. The cell below selects the simulator via `simulation_name`: `ReactionDiffusion` (the default) or `AdvectionDiffusion`, which simulates the advection-diffusion equation in 2D.


In [None]:
import pickle
from pathlib import Path

from autoemulate.simulations.advection_diffusion import AdvectionDiffusion
from autoemulate.simulations.reaction_diffusion import ReactionDiffusion

# Name of the simulation to run. "reaction_diffusion" selects ReactionDiffusion;
# any other value falls back to AdvectionDiffusion. The name is also used for the
# cache file further down.
simulation_name = "reaction_diffusion"

if simulation_name == "reaction_diffusion":
    Sim = ReactionDiffusion
else:
    Sim = AdvectionDiffusion

# Instantiate the chosen simulator with time-series output enabled and logging
# restricted to errors.
sim = Sim(return_timeseries=True, log_level="error")


def generate_split(simulator, n_train: int = 8, n_valid: int = 2, n_test: int = 2):
    """Draw independent train/valid/test sample sets from ``simulator``.

    Parameters
    ----------
    simulator
        Object exposing ``forward_samples_spatiotemporal(n)``.
    n_train, n_valid, n_test : int
        Number of samples requested for each split.

    Returns
    -------
    dict
        Mapping with keys ``"train"``, ``"valid"`` and ``"test"``, each holding
        whatever ``forward_samples_spatiotemporal`` returned for that split.
    """
    # Tuple order fixes the call order: train first, then valid, then test.
    split_sizes = (("train", n_train), ("valid", n_valid), ("test", n_test))
    return {
        split_name: simulator.forward_samples_spatiotemporal(n_samples)
        for split_name, n_samples in split_sizes
    }


# On-disk cache for the generated splits, named after the selected simulation
cache_file = Path(f"{simulation_name}_cache.pkl")

# Simulate and persist only when no cache exists yet; otherwise reuse the cache
if not cache_file.exists():
    print("Generating simulation data...")
    combined_data = generate_split(sim)
    print(f"Saving simulation data to {cache_file}")
    with cache_file.open("wb") as cache_handle:
        pickle.dump(combined_data, cache_handle)
else:
    print(f"Loading cached simulation data from {cache_file}")
    with cache_file.open("rb") as cache_handle:
        # NOTE: unpickling can execute arbitrary code; only load cache files
        # produced by this notebook or another trusted source.
        combined_data = pickle.load(cache_handle)


### Read combined data into datamodule


In [None]:
from auto_cast.data.datamodule import SpatioTemporalDataModule

# Windowing configuration. These names are reused by the model cell below, so
# they are defined once here.
n_steps_input = 4
n_steps_output = 4
# Single source of truth for the stride: it is passed to the datamodule here and
# to the model later, so changing it in one place keeps both in sync.
stride = 4

# Build the datamodule from the in-memory splits (data_path=None because the
# data is passed directly rather than read from disk).
datamodule = SpatioTemporalDataModule(
    data=combined_data,
    data_path=None,
    n_steps_input=n_steps_input,
    n_steps_output=n_steps_output,
    stride=stride,
    batch_size=16,
)

### Example batch


In [None]:
# Take the first batch from the training dataloader as a representative example.
# To inspect its contents, evaluate `batch` on its own as the cell's last line.
train_loader = datamodule.train_dataloader()
batch = next(iter(train_loader))

In [None]:
from auto_cast.decoders.channels_last import ChannelsLast
from auto_cast.encoders.permute_concat import PermuteConcat
from auto_cast.models.encoder_decoder import EncoderDecoder
from auto_cast.models.encoder_processor_decoder import EncoderProcessorDecoder
from auto_cast.processors.fno import FNOProcessor

# Fetch one training batch and use the size of the last axis of its input
# fields as the channel count.
batch = next(iter(datamodule.train_dataloader()))
n_channels = batch.input_fields.shape[-1]

# FNO processor. Its channel counts are the per-step channel count multiplied by
# the number of input/output time steps.
# NOTE(review): presumably the encoder folds time into the channel axis - confirm.
# NOTE(review): the processor now uses the shared `stride` so it agrees with the
# model built below (previously `n_steps_output`, identical at current settings);
# assumes both refer to the same rollout stride - confirm.
processor = FNOProcessor(
    in_channels=n_channels * n_steps_input,
    out_channels=n_channels * n_steps_output,
    n_modes=(16, 16),
    hidden_channels=64,
    stride=stride,
    max_rollout_steps=100,
)
encoder = PermuteConcat(with_constants=False)
decoder = ChannelsLast(output_channels=n_channels, time_steps=n_steps_output)

# Full model: the encoder/decoder pair wrapped around the processor.
model = EncoderProcessorDecoder(
    encoder_decoder=EncoderDecoder(encoder=encoder, decoder=decoder),
    processor=processor,
    stride=stride,
)

In [None]:
# Sanity-check the forward pass: run the model on one batch and show the shape
# of what it returns.
forward_output = model(batch)
forward_output.shape

### Run trainer


In [None]:
import lightning as L

# Let Lightning pick the best available accelerator so the notebook runs on any
# machine. A hard-coded "mps" fails where Apple's Metal backend is unavailable.
# Set this to "mps", "gpu" or "cpu" to force a specific device.
device = "auto"

# Short demonstration run: one epoch, logging every 10 steps.
trainer = L.Trainer(max_epochs=1, accelerator=device, log_every_n_steps=10)
trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())

### Run the evaluation


In [None]:
# Evaluate the trained model on the held-out test split; the returned results
# are displayed as the cell output.
test_loader = datamodule.test_dataloader()
trainer.test(model, test_loader)

### Example rollout


In [None]:
# The rollout dataloader yields full trajectories: a single element is the
# whole trajectory. Take the first one.
rollout_loader = datamodule.rollout_test_dataloader()
batch = next(iter(rollout_loader))

In [None]:
# Print the shape of the input fields (the first n_steps_input steps), then the
# shape of the output fields that follow them.
for fields in (batch.input_fields, batch.output_fields):
    print(fields.shape)

In [None]:
# Run rollout on one trajectory.
# The call returns a (predictions, ground truth) pair. Later cells assert that
# `trues` is not None, so the ground truth is apparently optional.
# NOTE(review): free_running_only=True presumably restricts the rollout to steps
# driven by the model's own predictions - confirm against the rollout API.
preds, trues = model.rollout(batch, free_running_only=True)

In [None]:
from auto_cast.metrics.spatiotemporal import MSE

# Require a ground truth with the same shape as the predictions before scoring.
assert trues is not None
assert preds.shape == trues.shape

# NOTE(review): the third positional argument (2) is passed to the metric as-is;
# its meaning is defined by the metric API - confirm.
mse_metric = MSE()
mse_error = mse_metric(preds, trues, 2)

In [None]:
# Shape of the predictions returned by the rollout
print(preds.shape)

In [None]:
# Guard against a missing ground truth before reading its shape
assert trues is not None
# Shape of the ground truth returned by the rollout
print(trues.shape)


In [None]:
from auto_cast.metrics.spatiotemporal import RMSE

# Require a ground truth before computing the root-mean-square error.
assert trues is not None

# NOTE(review): the third positional argument (2) is passed to the metric as-is;
# its meaning is defined by the metric API - confirm.
rmse_metric = RMSE()
rmse_error = rmse_metric(preds, trues, 2)

In [None]:
from IPython.display import HTML

from auto_cast.utils import plot_spatiotemporal_video

# Build an animation from the rollout predictions and the ground truth, then
# embed it in the notebook as interactive JavaScript/HTML.
anim = plot_spatiotemporal_video(pred=preds, true=trues)
animation_html = anim.to_jshtml()
HTML(animation_html)