## AutoCast encoder-processor-decoder model API Exploration

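This notebook walks through the AutoCast API end to end: generating example data, loading it into a datamodule, building an encoder-processor-decoder model, training it, and evaluating it on the test split.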


### Example dataset

We use the `AdvectionDiffusion` simulator from AutoEmulate as the example dataset to illustrate model training and evaluation. It simulates the advection-diffusion equation in 2D.
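For orientation, a generic scalar advection-diffusion equation is written out here. This is the textbook form with a constant diffusion coefficient; the variables and parameterization that `AdvectionDiffusion` actually uses are not stated in this notebook, so treat the symbols as assumptions.

```latex
\frac{\partial u}{\partial t} + \mathbf{v} \cdot \nabla u = D \, \nabla^{2} u
```

Here $u(x, y, t)$ is the transported field, $\mathbf{v}$ is the velocity field that advects it, and $D$ is the diffusion coefficient.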

In [None]:

from autoemulate.simulations.advection_diffusion import AdvectionDiffusion

sim = AdvectionDiffusion(return_timeseries=True, log_level="error")

def generate_split(
    simulator: AdvectionDiffusion, n_train: int = 4, n_valid: int = 2, n_test: int = 2
):
    """Generate training, validation, and test splits from the simulator."""
    train = simulator.forward_samples_spatiotemporal(n_train)
    valid = simulator.forward_samples_spatiotemporal(n_valid)
    test = simulator.forward_samples_spatiotemporal(n_test)
    return {"train": train, "valid": valid, "test": test}


combined_data = generate_split(sim)

### Load the combined data into a datamodule


In [None]:
from auto_cast.data.datamodule import SpatioTemporalDataModule

datamodule = SpatioTemporalDataModule(
    data=combined_data, data_path=None, n_steps_input=4, n_steps_output=1, batch_size=16
)
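The `n_steps_input=4` and `n_steps_output=1` arguments suggest that each training example pairs four consecutive time steps with the one that follows. The sketch below illustrates that idea on a plain Python list. It is not the `SpatioTemporalDataModule` implementation: the helper `sliding_windows` and its `stride` parameter are hypothetical, and whether the datamodule uses overlapping windows is an assumption.

```python
def sliding_windows(trajectory, n_steps_input=4, n_steps_output=1, stride=1):
    """Split one trajectory (a sequence of time steps) into (input, target) pairs.

    Hypothetical helper for illustration only; not part of auto_cast.
    """
    window = n_steps_input + n_steps_output
    pairs = []
    # Slide a window of length n_steps_input + n_steps_output along the trajectory.
    for start in range(0, len(trajectory) - window + 1, stride):
        inputs = trajectory[start : start + n_steps_input]
        targets = trajectory[start + n_steps_input : start + window]
        pairs.append((inputs, targets))
    return pairs


# Stand-in trajectory of 7 time steps, labelled by index for readability.
trajectory = list(range(7))
pairs = sliding_windows(trajectory)
for inputs, targets in pairs:
    print(inputs, "->", targets)
```

Under these assumptions a trajectory of 7 steps yields 3 examples, so short simulations produce few training pairs per trajectory.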

### Example batch


In [None]:
batch = next(iter(datamodule.train_dataloader()))

# batch
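Printing the whole batch is noisy, so a shape summary is often more useful. The helper below is a minimal sketch that walks a nested batch and reports the shape of anything exposing a `.shape` attribute. The structure of the AutoCast batch is not shown in this notebook, so `summarize`, `FakeTensor`, and the field names in the demo are all hypothetical stand-ins.

```python
def summarize(obj, name="batch", indent=0):
    """Return lines describing the shapes found in a (possibly nested) batch."""
    pad = "  " * indent
    shape = getattr(obj, "shape", None)
    if shape is not None:
        # Tensors and arrays: report the shape and stop descending.
        return [f"{pad}{name}: shape={tuple(shape)}"]
    if isinstance(obj, dict):
        items = list(obj.items())
    elif isinstance(obj, (list, tuple)):
        items = list(enumerate(obj))
    elif hasattr(obj, "__dict__"):
        # Plain objects and dataclasses: inspect their attributes.
        items = list(vars(obj).items())
    else:
        return [f"{pad}{name}: {type(obj).__name__}"]
    lines = [f"{pad}{name}: {type(obj).__name__}"]
    for key, value in items:
        lines.extend(summarize(value, str(key), indent + 1))
    return lines


class FakeTensor:
    """Stand-in for a tensor; only carries a shape."""

    def __init__(self, *shape):
        self.shape = shape


# Placeholder field names and shapes, chosen only to exercise the helper.
demo = {"inputs": FakeTensor(2, 4, 8, 8, 1), "targets": FakeTensor(2, 1, 8, 8, 1)}
summary = summarize(demo)
print("\n".join(summary))
```

In the notebook, `print("\n".join(summarize(batch)))` would apply the same helper to the real batch.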

### Build the model

In [None]:
from auto_cast.decoders.channels_last import ChannelsLast
from auto_cast.encoders.permute_concat import PermuteConcat
from auto_cast.models.encoder_decoder import EncoderDecoder
from auto_cast.models.encoder_processor_decoder import EncoderProcessorDecoder
from auto_cast.nn.fno import FNOProcessor

processor = FNOProcessor(
    in_channels=1, out_channels=1, n_modes=(16, 16, 1), hidden_channels=64
)
encoder = PermuteConcat(with_constants=False)
decoder = ChannelsLast()

model = EncoderProcessorDecoder.from_encoder_processor_decoder(
    encoder_decoder=EncoderDecoder(encoder=encoder, decoder=decoder),
    processor=processor,
)

### Run the trainer


In [None]:
import lightning as L

accelerator = "auto"  # picks an available backend; set "mps" or "cpu" to force one
trainer = L.Trainer(max_epochs=5, accelerator=accelerator, log_every_n_steps=10)
trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())

### Run the evaluation

In [None]:
trainer.test(model, datamodule.test_dataloader())
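Lightning's `trainer.test` returns a list of dictionaries, one per test dataloader, holding the metrics logged during the test phase. The sketch below formats such a list for display. The helper `format_metrics` is hypothetical, and the metric name and value in `example_results` are placeholders, not output from this model.

```python
def format_metrics(results):
    """Format the list of metric dictionaries returned by a test run."""
    lines = []
    for i, metrics in enumerate(results):
        # Sort by metric name so the output order is stable.
        for name, value in sorted(metrics.items()):
            lines.append(f"dataloader {i} | {name}: {float(value):.4e}")
    return lines


# Placeholder results; the real metric names depend on what the model logs.
example_results = [{"test_loss": 0.0123}]
formatted = format_metrics(example_results)
print("\n".join(formatted))
```

In the notebook, assigning `results = trainer.test(model, datamodule.test_dataloader())` and passing `results` to the helper would print the logged test metrics.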