In [None]:
# To try a different simulator, swap the import below for e.g.
#   from autoemulate.simulations.advection_diffusion import AdvectionDiffusion as Sim
# (the code below refers to the simulator class only through the alias ``Sim``).
from autoemulate.simulations.reaction_diffusion import ReactionDiffusion as Sim

# return_timeseries=True presumably keeps every simulated time step, which the
# spatiotemporal splits below rely on — TODO confirm.
# log_level="error" presumably suppresses non-error simulator logging.
sim = Sim(return_timeseries=True, log_level="error")


def generate_split(simulator: Sim, n_train: int = 4, n_valid: int = 2, n_test: int = 2):
    """Draw an independent set of spatiotemporal samples for each data split.

    Each split comes from its own call to
    ``simulator.forward_samples_spatiotemporal``. The calls are made in
    train -> valid -> test order.

    Args:
        simulator: Object exposing ``forward_samples_spatiotemporal(n)``.
        n_train: Number of samples requested for the training split.
        n_valid: Number of samples requested for the validation split.
        n_test: Number of samples requested for the test split.

    Returns:
        Dict with keys ``"train"``, ``"valid"`` and ``"test"``, each mapping to
        the simulator's output for that split.
    """
    split_sizes = (("train", n_train), ("valid", n_valid), ("test", n_test))
    return {
        split_name: simulator.forward_samples_spatiotemporal(count)
        for split_name, count in split_sizes
    }


# Build the train/valid/test splits using generate_split's default sample counts.
combined_data = generate_split(sim)

In [None]:
from auto_cast.data.datamodule import SpatioTemporalDataModule

datamodule = SpatioTemporalDataModule(
    # Data comes from the in-memory splits generated above, not from disk.
    data=combined_data,
    data_path=None,
    # Autoencoder setup: one input step and no separate output steps
    # (targets presumably mirror the inputs in this mode — TODO confirm).
    autoencoder_mode=True,
    n_steps_input=1,
    n_steps_output=0,
    batch_size=16,
)

In [None]:
# Pull one training batch to inspect tensor shapes; the model cell below also
# takes its channel count from this batch.
batch = next(iter(datamodule.train_dataloader()))


In [None]:
# Inspect the input field shape. With n_steps_input=1 each sample presumably
# holds a single frame; the last axis is used as the channel axis later on
# (two channels expected for this simulator — TODO confirm).
batch.input_fields.shape

In [None]:
import torch

# Sanity check: with autoencoder_mode=True the targets are expected to equal the
# inputs, so this should display True — TODO confirm against SpatioTemporalDataModule.
torch.allclose(batch.input_fields, batch.output_fields)


In [None]:
from auto_cast.decoders.dc import DCDecoder
from auto_cast.encoders.dc import DCEncoder
from auto_cast.models.ae import AE

# The last axis of the input fields is treated as the channel axis.
channels = batch.input_fields.shape[-1]

# Shared architecture hyper-parameters. The decoder reuses them (hidden widths
# and block counts in reverse order) so encoder and decoder always agree on
# the latent width instead of repeating the same magic numbers twice.
latent_channels = 16
hid_channels = (32, 64)
hid_blocks = (2, 2)

encoder = DCEncoder(
    in_channels=channels,
    out_channels=latent_channels,
    hid_channels=hid_channels,
    spatial=2,
    hid_blocks=hid_blocks,
    pixel_shuffle=False,
)

decoder = DCDecoder(
    in_channels=latent_channels,
    out_channels=channels,
    hid_channels=tuple(reversed(hid_channels)),
    spatial=2,
    hid_blocks=tuple(reversed(hid_blocks)),
    pixel_shuffle=False,
)
model = AE(encoder=encoder, decoder=decoder)

In [None]:
import lightning as L
import torch

# Pick the best available accelerator instead of hard-coding "mps", which fails
# on machines without Apple-silicon GPU support (e.g. Linux/CUDA hosts or CI).
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

trainer = L.Trainer(max_epochs=5, accelerator=device, log_every_n_steps=10)
trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())

In [None]:
import itertools

import matplotlib.pyplot as plt
import torch

N_PLOT_BATCHES = 4  # number of test batches to visualise
FIELD_CHANNEL = 0  # channel (last axis) of the physical fields to display
LATENT_CHANNEL = 0  # channel (last axis) of the latent representation to display


def _first_frame(tensor: torch.Tensor, channel: int):
    """Return ``tensor[0, 0, :, :, channel]`` as a NumPy array ready for ``imshow``."""
    return tensor[0, 0, :, :, channel].detach().cpu().numpy()


# Run inference on the CPU in eval mode with autograd disabled. The dataloader's
# batches are never moved to an accelerator here (assumed to be on the CPU), so
# pinning the model to the CPU keeps model and data on the same device
# regardless of which accelerator was used for training.
model.cpu()
model.eval()

with torch.no_grad():
    # Distinct loop name so the global `batch` used by the model cell is not
    # overwritten with a test batch.
    for test_batch in itertools.islice(datamodule.test_dataloader(), N_PLOT_BATCHES):
        inputs = test_batch.input_fields
        outputs, latents = model.forward_with_latent(test_batch)
        print("Input shape:", inputs.shape)
        print("Output shape:", outputs.shape)
        print("Latent shape:", latents.shape)

        input_img = _first_frame(inputs, FIELD_CHANNEL)
        recon_img = _first_frame(outputs, FIELD_CHANNEL)
        diff_img = recon_img - input_img
        # Symmetric limits so zero error sits at the centre of the diverging
        # colormap; fall back to 1.0 when the difference is identically zero.
        diff_lim = float(abs(diff_img).max()) or 1.0

        fig, axs = plt.subplots(1, 4, figsize=(8, 4))
        axs[0].imshow(input_img, cmap="viridis")
        axs[0].set_title("Input")
        axs[1].imshow(recon_img, cmap="viridis")
        axs[1].set_title("Reconstruction")
        diff_im = axs[2].imshow(diff_img, cmap="RdBu_r", vmin=-diff_lim, vmax=diff_lim)
        axs[2].set_title("Difference")
        fig.colorbar(diff_im, ax=axs[2], fraction=0.046)
        axs[3].imshow(_first_frame(latents, LATENT_CHANNEL), cmap="viridis")
        axs[3].set_title(f"Latent dim {LATENT_CHANNEL}")
        plt.show()
        # Release the figure so repeated loop iterations don't accumulate memory.
        plt.close(fig)