## AutoCast processor training example

This notebook demonstrates training a processor directly on encoded data.

### Example dataset

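We use the `rayleigh_benard` dataset (selected via `simulation_name` in the next cell) as an example dataset to illustrate training and evaluation of models. This dataset simulates Rayleigh–Bénard convection in 2D; the data module below loads it from a pre-encoded cache, so the processor is trained on the encoded representation rather than on the raw fields.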


In [None]:
from autocast.data.encoded_dataset import MiniWellDataModule
from autocast.metrics.spatiotemporal import MAE, MSE, RMSE, VRMSE

# --- Experiment configuration ------------------------------------------------
# NOTE(review): THE_WELL and rollout_stride are not read by any cell visible in
# this notebook; they are kept so the notebook's global names stay unchanged.
THE_WELL = False
simulation_name = "rayleigh_benard"

# Windowing parameters handed to the data module (and later to the backbone and
# processor). Presumably counts of time steps per sample — confirm against
# MiniWellDataModule.
n_steps_input, n_steps_output = 1, 4
stride = 1
rollout_stride = 4

# Location of the cached, already-encoded data, relative to this notebook.
base_path = f"../datasets/{simulation_name}/1e3z5x2c_{simulation_name}_dcae_f32c64_large/cache/{simulation_name}"

# Data module serving the encoded samples used by every later cell.
datamodule = MiniWellDataModule(
    data_path=base_path,
    n_steps_input=n_steps_input,
    n_steps_output=n_steps_output,
    stride=stride,
)

### Set-up logging


In [None]:
from autocast.logging import maybe_watch_model
from autocast.logging.wandb import create_notebook_logger

# Name and tag the run after the simulation so runs are easy to tell apart.
run_name = f"06_processor_{simulation_name}"
run_tags = ["notebook", simulation_name]

# `enabled=False` is passed here — presumably this switches remote logging off;
# confirm against create_notebook_logger before relying on it.
# `watch` is forwarded to maybe_watch_model once the model has been built.
logger, watch = create_notebook_logger(
    project="autocast-notebooks",
    name=run_name,
    tags=run_tags,
    enabled=False,
)

In [None]:
# Pull a single training batch and read dimensions off its encoded inputs.
batch = next(iter(datamodule.train_dataloader()))
encoded_shape = batch.encoded_inputs.shape
n_channels = encoded_shape[-1]  # size of the last axis
w, h = encoded_shape[2:4]  # axes 2 and 3 — presumably the spatial dims; TODO confirm

### Example shape and batch


In [None]:
# Shape of the encoded inputs for the first item of the training dataset
# (indexed directly, i.e. not drawn through the dataloader).
first_sample = datamodule.train_dataset[0]
first_sample.encoded_inputs.shape

In [None]:
# Same quantity as the previous cell, but for a batch drawn from the training
# dataloader, to compare against the single-sample shape.
train_loader = datamodule.train_dataloader()
batch = next(iter(train_loader))
batch.encoded_inputs.shape

In [None]:

from azula.noise import VPSchedule

from autocast.models.processor import ProcessorModel
from autocast.nn.unet import TemporalUNetBackbone
from autocast.processors.flow_matching import FlowMatchingProcessor

# Draw one training batch to read the channel counts from (last axis of the
# encoded inputs / encoded output fields).
batch = next(iter(datamodule.train_dataloader()))
n_latent_in = batch.encoded_inputs.shape[-1]
n_latent_out = batch.encoded_output_fields.shape[-1]
n_channels = n_latent_in  # alias kept so this name stays defined as before

# Which generative processor to train: "diffusion" or "flow_matching".
processor_name = "diffusion"

# Fail fast on a misspelt name. Previously any value other than
# "flow_matching" silently fell through to the diffusion branch.
valid_processor_names = ("flow_matching", "diffusion")
if processor_name not in valid_processor_names:
    raise ValueError(
        f"Unknown processor_name {processor_name!r}; "
        f"expected one of {valid_processor_names}"
    )

# Backbone shared by both processor types. It maps `n_latent_out` channels to
# `n_latent_out` channels and receives the `n_latent_in` input channels as
# conditioning.
backbone = TemporalUNetBackbone(
    in_channels=n_latent_out,
    out_channels=n_latent_out,
    cond_channels=n_latent_in,
    n_steps_output=n_steps_output,
    n_steps_input=n_steps_input,
    mod_features=200,
    hid_channels=(32, 64, 128),
    hid_blocks=(2, 2, 2),
    spatial=2,
    periodic=False,
)

if processor_name == "flow_matching":
    processor = FlowMatchingProcessor(
        backbone=backbone,
        schedule=VPSchedule(),  # accepted for API parity, not used internally
        n_steps_output=n_steps_output,
        n_channels_out=n_latent_out,
        stride=stride,
        flow_ode_steps=4,
    )
else:  # "diffusion" — the only other value the check above lets through
    # Imported here, as in the original cell, so it is only loaded when the
    # diffusion processor is selected.
    from autocast.processors.diffusion import DiffusionProcessor

    processor = DiffusionProcessor(
        backbone=backbone,
        schedule=VPSchedule(),
        n_steps_output=n_steps_output,
        n_channels_out=n_latent_out,
    )

# Wrap the processor in the trainable model and register the test-time metrics.
model = ProcessorModel(
    processor=processor,
    learning_rate=5e-4,
    test_metrics=[VRMSE(), MSE(), MAE(), RMSE()],
)
# Hand the model to the logger helper using the `watch` value returned when the
# logger was created.
maybe_watch_model(logger, model, watch)

In [None]:
# Smoke check before training: run the model once on the sampled batch's
# encoded inputs and display the shape of what comes back.
prediction = model(batch.encoded_inputs)
prediction.shape

### Run trainer


In [None]:
import lightning as L

# Let Lightning pick whichever accelerator is available instead of hard-coding
# "cuda", which fails on machines without an NVIDIA GPU. Set this to "cuda",
# "mps" or "cpu" to force a specific backend.
device = "auto"

# Short demonstration run (2 epochs), logging through the notebook logger.
trainer = L.Trainer(max_epochs=2, accelerator=device, logger=logger)
trainer.fit(model, datamodule)

# Save the trained weights next to the notebook; the file name records both the
# simulation and the processor type.
trainer.save_checkpoint(f"./{simulation_name}_{processor_name}_model.ckpt")

### Run the evaluation


In [None]:
# Run the trainer's test loop on the data module; the value it returns is shown
# as the cell output.
test_results = trainer.test(model, datamodule)
test_results