In [None]:
from autoemulate.simulations.advection_diffusion import AdvectionDiffusion as Sim

# Advection–diffusion simulator used to generate all data in this notebook.
# NOTE(review): `return_timeseries=True` presumably makes it emit full trajectories
# rather than final states, and `log_level="error"` quiets lower-severity logs —
# confirm against AdvectionDiffusion.
sim = Sim(log_level="error", return_timeseries=True)


def generate_split(simulator: Sim, n_train: int = 10, n_valid: int = 2, n_test: int = 2):
    """Draw independent train/valid/test sample sets from ``simulator``.

    Each split comes from its own call to
    ``simulator.forward_samples_spatiotemporal``, made in the order
    train, valid, test.

    Args:
        simulator: Object exposing ``forward_samples_spatiotemporal(n)``.
        n_train: Number of samples requested for the training split.
        n_valid: Number of samples requested for the validation split.
        n_test: Number of samples requested for the test split.

    Returns:
        dict: ``{"train": ..., "valid": ..., "test": ...}`` holding whatever
        ``forward_samples_spatiotemporal`` returned for each split.
    """
    # dicts preserve insertion order, so the simulator is sampled train -> valid -> test.
    split_sizes = {"train": n_train, "valid": n_valid, "test": n_test}
    return {
        split_name: simulator.forward_samples_spatiotemporal(size)
        for split_name, size in split_sizes.items()
    }


# Build the three splits, spelling out the default sizes (10 train / 2 valid / 2 test).
combined_data = generate_split(sim, n_train=10, n_valid=2, n_test=2)

In [None]:
from autocast.data.datamodule import SpatioTemporalDataModule

# Window sizes: number of conditioning (input) and predicted (output) time steps.
n_steps_input, n_steps_output = 1, 4

# Samples are supplied directly via `data`, so no on-disk `data_path` is used.
datamodule = SpatioTemporalDataModule(
    data=combined_data,
    data_path=None,
    batch_size=16,
    n_steps_input=n_steps_input,
    n_steps_output=n_steps_output,
)

In [None]:
# Peek at one training batch to sanity-check the input/output window shapes.
train_loader = datamodule.train_dataloader()
batch = next(iter(train_loader))
(batch.input_fields.shape, batch.output_fields.shape)

In [None]:
import torch
from azula.noise import CosineSchedule

from autocast.decoders.identity import IdentityDecoder
from autocast.encoders.identity import IdentityEncoder
from autocast.models.encoder_decoder import EncoderDecoder
from autocast.models.encoder_processor_decoder import EPDTrainProcessor
from autocast.nn.unet import TemporalUNetBackbone
from autocast.nn.vit import TemporalViTBackbone
from autocast.processors.diffusion import DiffusionProcessor



### Set up the backbone

In [None]:
# Read the channel count off a real batch instead of hard-coding it.
# NOTE(review): assumes the channel axis is the last axis of `input_fields` — confirm.
batch = next(iter(datamodule.train_dataloader()))
n_channels = batch.input_fields.shape[-1]
print("Number of channels:", n_channels)

# Noise schedule, handed to the DiffusionProcessor in the model-setup cell.
schedule = CosineSchedule()
mod_features = 128

backbone = TemporalViTBackbone(
    # channel counts: predict and condition on the same fields
    in_channels=n_channels,
    out_channels=n_channels,
    cond_channels=n_channels,
    # temporal windowing, matching the datamodule
    n_steps_input=n_steps_input,
    n_steps_output=n_steps_output,
    temporal_method="attention",
    # ViT architecture
    mod_features=mod_features,
    hid_channels=512,  # transformer hidden dimension
    hid_blocks=8,  # number of transformer blocks
    attention_heads=8,  # attention heads per block
    patch_size=5,  # spatial patch size
    spatial=2,
)


### Initialize the models

In [None]:
# Training-time rollout horizon: every step after the conditioning window.
# NOTE(review): `total_timesteps` is hard-coded; confirm it matches the
# trajectory length produced by the simulator.
total_timesteps = 320
max_rollout_steps = total_timesteps - n_steps_input
stride = 4  # later assigned to `model.stride`

# Diffusion processor: predicts a window of `n_steps_output` steps per call.
processor = DiffusionProcessor(
    backbone=backbone,
    schedule=schedule,
    denoiser_type="karras",
    learning_rate=1e-4,
    n_steps_output=n_steps_output,
)

# Identity encoder/decoder: the processor operates directly on the raw fields.
encoder, decoder = IdentityEncoder(), IdentityDecoder()
model = EPDTrainProcessor(
    encoder_decoder=EncoderDecoder(encoder=encoder, decoder=decoder),
    processor=processor,
    max_rollout_steps=max_rollout_steps,
)

In [None]:
import lightning as L
from autocast.logging import create_wandb_logger, maybe_watch_model

# Choose the best available accelerator: CUDA first, then MPS, then CPU.
# To force CPU (e.g. while debugging), reassign `device = "cpu"` after this cell.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("Using device:", device)

In [None]:
from autocast.logging.wandb import create_notebook_logger

# Weights & Biases logger for this notebook run.
# NOTE(review): the run name/tags say "diffusion_reaction" but the data come
# from the AdvectionDiffusion simulator — confirm these labels are intended.
wandb_logger, wandb_watch = create_notebook_logger(
    project="autocast-notebooks",
    name="03_diffusion_reaction_test",
    tags=["notebook", "03-diffusion-reaction-test"],
)

trainer = L.Trainer(
    accelerator=device,
    max_epochs=3,
    log_every_n_steps=10,
    logger=wandb_logger,
)

# Optionally attach model watching; whether it happens is decided by `wandb_watch`.
maybe_watch_model(wandb_logger, model, wandb_watch)

In [None]:
# Train, validating on the held-out split each epoch.
trainer.fit(
    model,
    train_dataloaders=datamodule.train_dataloader(),
    val_dataloaders=datamodule.val_dataloader(),
)

In [None]:
# Evaluate on the test split; the returned metrics are displayed as the cell output.
trainer.test(model, dataloaders=datamodule.test_dataloader())

In [None]:
# Store the rollout stride (defined in the model-setup cell) on the model.
# NOTE(review): the rollout cell also passes `stride=` explicitly to `model.rollout`,
# so it is unclear whether anything reads `model.stride` — confirm or drop this cell.
model.stride = stride

In [None]:
# After training, run a free-running autoregressive rollout on the rollout test split.
# The horizon is deliberately short (overriding the training-time `max_rollout_steps`)
# to keep the demo cheap; it is overridden on both the model and its processor.
n_rollout_steps = 5
model.max_rollout_steps = n_rollout_steps
model.processor.max_rollout_steps = n_rollout_steps

rollout_batch = next(iter(datamodule.rollout_test_dataloader()))
# Reuse `stride` from the model-setup cell rather than re-hardcoding it here.
preds, trues = model.rollout(rollout_batch, stride=stride, free_running_only=True)

# NOTE(review): the time dimension should reflect the short `n_rollout_steps`
# horizon above, not the full trajectory — check the printed shape.
print(f"Predictions shape: {preds.shape}")



In [None]:
# Compare first and second moments of predicted vs. ground-truth rollouts.
# Both pairs go into one expression because a notebook cell displays only its
# last expression.
{
    "mean (pred, true)": (preds.mean(), trues.mean()),  # type: ignore
    "std (pred, true)": (preds.std(), trues.std()),  # type: ignore
}

In [None]:
from IPython.display import HTML

from autocast.utils import plot_spatiotemporal_video

# Animate predicted vs. ground-truth fields and embed the result inline as an
# HTML/JavaScript player.
anim = plot_spatiotemporal_video(true=trues, pred=preds, cmap="plasma")
player_html = anim.to_jshtml()
HTML(player_html)