In [None]:
from autoemulate.simulations.advection_diffusion import AdvectionDiffusion as Sim

sim = Sim(log_level="error", return_timeseries=True)


def generate_split(simulator: Sim, n_train: int = 10, n_valid: int = 2, n_test: int = 2):
    """Draw separate train/validation/test sample sets from ``simulator``.

    Each split comes from its own call to
    ``simulator.forward_samples_spatiotemporal`` with the requested sample
    count. The calls are made in the order train, valid, test.

    Args:
        simulator: Simulator instance to sample from.
        n_train: Number of samples in the training split.
        n_valid: Number of samples in the validation split.
        n_test: Number of samples in the test split.

    Returns:
        Dict with keys ``"train"``, ``"valid"`` and ``"test"``. Each key maps
        to the simulator output for that split.
    """
    # Insertion order fixes the sampling order (train -> valid -> test).
    split_sizes = {"train": n_train, "valid": n_valid, "test": n_test}
    return {
        split_name: simulator.forward_samples_spatiotemporal(count)
        for split_name, count in split_sizes.items()
    }


combined_data = generate_split(sim)

In [None]:
from auto_cast.data.datamodule import SpatioTemporalDataModule

# Window sizes: condition on `n_steps_input` steps, predict `n_steps_output` steps.
n_steps_input = 1
n_steps_output = 4
datamodule = SpatioTemporalDataModule(
    data=combined_data,  # in-memory splits produced above
    data_path=None,
    batch_size=16,
    n_steps_input=n_steps_input,
    n_steps_output=n_steps_output,
)

In [None]:
batch = next(iter(datamodule.train_dataloader()))

# Peek at the input/output tensor shapes of one training batch.
tuple(getattr(batch, field).shape for field in ("input_fields", "output_fields"))

In [None]:
import torch
from azula.noise import CosineSchedule

from auto_cast.decoders.identity import IdentityDecoder
from auto_cast.encoders.identity import IdentityEncoder
from auto_cast.models.encoder_decoder import EncoderDecoder
from auto_cast.models.encoder_processor_decoder import EPDTrainProcessor
from auto_cast.nn.unet import SimpleUNet, TemporalUNetBackbone
from auto_cast.processors.diffusion import DiffusionProcessor


### Set up the backbone

In [None]:
batch = next(iter(datamodule.train_dataloader()))
# Channels per time step, read from the last axis of the input fields
# (assumes a channels-last layout — TODO confirm against the datamodule).
n_channels = batch.input_fields.shape[-1]
print("Number of channels:", n_channels)

# Noise schedule for the diffusion processor.
schedule = CosineSchedule()
mod_features = 128

# Channel counts are (channels per step) x (number of steps), so they scale
# with n_channels rather than being fixed numbers.
backbone = TemporalUNetBackbone(
    in_channels=n_channels * n_steps_output,   # noisy output window only
    out_channels=n_channels * n_steps_output,  # one prediction per output step
    cond_channels=n_channels * n_steps_input,  # previous (conditioning) time steps
    mod_features=mod_features,
)
# NOTE(review): `SimpleUNet` (imported above) was previously sketched here with
# the same keyword arguments as an alternative backbone — verify before swapping.


### Initialize the models

In [None]:
# Length of a full simulated trajectory. It is hard-coded, and presumably
# matches the simulator's number of time steps — TODO confirm.
total_timesteps = 320

# Longest rollout: every step after the conditioning window (320 - 1 = 319 here).
max_rollout_steps = total_timesteps - n_steps_input
stride = 4

processor = DiffusionProcessor(
    backbone=backbone,
    schedule=schedule,
    denoiser_type='karras',
    n_steps_output=n_steps_output,  # size of each predicted window
    stride=stride,
    max_rollout_steps=max_rollout_steps,
    teacher_forcing_ratio=0.0,
    learning_rate=1e-4,
)
encoder = IdentityEncoder()
decoder = IdentityDecoder()

# Identity encoder/decoder: the processor works directly on the raw fields.
model = EPDTrainProcessor(
    encoder_decoder=EncoderDecoder(encoder=encoder, decoder=decoder),
    processor=processor,
    max_rollout_steps=max_rollout_steps,
)

In [None]:
import lightning as L
from lightning.pytorch.loggers import WandbLogger
import wandb
# Pick the best available accelerator, preferring CUDA, then Apple MPS, then CPU.
# Assign `device = "cpu"` after this block to force CPU execution.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("Using device:", device)


In [None]:
# Experiment tracking via Weights & Biases.
wandb_logger = WandbLogger(
    name="test-run-1",
    project="bout-diffusion",
)

trainer = L.Trainer(
    accelerator=device,
    max_epochs=3,
    log_every_n_steps=10,
    logger=wandb_logger,
)


In [None]:
trainer.fit(
    model,
    train_dataloaders=datamodule.train_dataloader(),
    val_dataloaders=datamodule.val_dataloader(),
)

In [None]:
trainer.test(model, dataloaders=datamodule.test_dataloader())

In [None]:
# EPDTrainProcessor was built without a stride argument, so copy the
# processor's stride onto the wrapper here. Presumably it is read by
# `model.rollout` — TODO confirm.
setattr(model, "stride", stride)

In [None]:
# Free-running rollout after training. The horizon is capped at a fixed number
# of steps, shorter than the training-time max_rollout_steps computed earlier.
rollout_steps = 80
model.max_rollout_steps = rollout_steps
model.processor.max_rollout_steps = rollout_steps

batch = next(iter(datamodule.rollout_test_dataloader()))

# Inference only: switch to eval mode and disable autograd. Otherwise the
# multi-step diffusion rollout builds a graph it never uses, and the returned
# tensors can require grad, which breaks later numpy/matplotlib conversion.
model.eval()
with torch.no_grad():
    preds, trues = model.rollout(batch, free_running_only=True)

# The time axis should reflect `rollout_steps` rather than the full trajectory
# length — verify the exact layout against EPDTrainProcessor.rollout.
print(f"Predictions shape: {preds.shape}")



In [None]:
# Compare first and second moments of the rollout against ground truth.
# Both pairs live in a single expression because a notebook cell only displays
# its last expression; a separate line for the means would be silently discarded.
{
    "mean (pred, true)": (preds.mean().item(), trues.mean().item()),
    "std (pred, true)": (preds.std().item(), trues.std().item()),
}

In [None]:
from auto_cast.utils import plot_spatiotemporal_video
from IPython.display import HTML

# Animate the rollout against ground truth and render it inline.
# `anim` is presumably a matplotlib Animation, since it exposes `to_jshtml`.
anim = plot_spatiotemporal_video(
    true=trues,
    pred=preds,
    cmap="plasma",
)
HTML(anim.to_jshtml())