# BOUT - Diffusion test

In [None]:
import torch
import torch.nn as nn

from autocast.data.datamodule import SpatioTemporalDataModule
from autocast.data.dataset import BOUTDataset
from autocast.types import EncodedBatch

# Relative path handed to the data module as its `data_path` argument.
data_path = "data/bout_split"


In [None]:
# Window sizes forwarded to the data module (and, later, to the processor).
n_steps_input, n_steps_output = 1, 4

datamodule = SpatioTemporalDataModule(
    dataset_cls=BOUTDataset,
    data_path=data_path,
    n_steps_input=n_steps_input,
    n_steps_output=n_steps_output,
    stride=1,
    batch_size=1,
    ftype="torch",
    dtype=torch.float32,
    verbose=True,
)

# Pull a single training batch as a sanity check.
train_loader = datamodule.train_dataloader()
batch = next(iter(train_loader))

# Last expression of the cell: the notebook displays these three shapes.
(
    batch.input_fields.shape,
    batch.output_fields.shape,
    batch.constant_scalars.shape,
)


In [None]:
import torch
from azula.noise import CosineSchedule

from autocast.decoders.identity import IdentityDecoder
from autocast.encoders.identity import IdentityEncoder
from autocast.models.encoder_decoder import EncoderDecoder
from autocast.models.encoder_processor_decoder import EPDTrainProcessor
from autocast.nn.unet import TemporalUNetBackbone
from autocast.processors.diffusion import DiffusionProcessor


In [None]:
# Read the channel count off the last axis of one training batch rather than
# hard-coding it.
batch = next(iter(datamodule.train_dataloader()))
n_channels = batch.input_fields.shape[-1]
print("Number of channels:", n_channels)

# Noise schedule consumed by the diffusion processor.
schedule = CosineSchedule()

# Width of the modulation features given to the backbone.
mod_features = 128

# Input, output and conditioning channels all use the data's channel count.
backbone = TemporalUNetBackbone(
    in_channels=n_channels,
    out_channels=n_channels,
    cond_channels=n_channels,
    mod_features=mod_features,
    hid_channels=(32, 64, 128),
    hid_blocks=(2, 2, 2),
    spatial=2,
    periodic=False,
    temporal_method="tcn",
)

In [None]:
# NOTE(review): `total_timesteps` and `stride` are not referenced by any other
# visible cell (the rollout cell hard-codes its own stride) - confirm whether
# they are still needed.
total_timesteps = 112
stride = 4

# Hard-coded value passed to the model below as `max_rollout_steps`.
max_rollout_steps = 5

# Diffusion processor wrapping the backbone and noise schedule built earlier.
processor = DiffusionProcessor(
    backbone=backbone,
    schedule=schedule,
    denoiser_type="karras",
    n_steps_output=n_steps_output,  # same output window as the data module
    learning_rate=1e-3,
)

# Identity encoder/decoder pair (presumably pass-through - verify in autocast).
encoder = IdentityEncoder()
decoder = IdentityDecoder()
encoder_decoder = EncoderDecoder(encoder=encoder, decoder=decoder)

model = EPDTrainProcessor(
    encoder_decoder=encoder_decoder,
    processor=processor,
    max_rollout_steps=max_rollout_steps,
)

In [None]:
import lightning as L
from lightning.pytorch.loggers import WandbLogger
import wandb
# Pick the accelerator in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("Using device:", device)
# Uncomment to force CPU execution:
# device = "cpu"


In [None]:
import lightning as L

# NOTE(review): an identical device-selection cell appears earlier in this
# notebook; one of the two can be removed.
# Start from the CPU fallback and upgrade to an accelerator when one is present.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
print("Using device:", device)
# Uncomment to force CPU execution:
# device = "cpu"


In [None]:
# Weights & Biases logger for this run.
wandb_logger = WandbLogger(name="test-run-1", project="bout-diffusion")

# Short 3-epoch run on the selected device, logging every 10 steps.
trainer = L.Trainer(
    accelerator=device,
    max_epochs=3,
    logger=wandb_logger,
    log_every_n_steps=10,
)


In [None]:
# Disabled cell: everything below is a bare triple-quoted string, so the
# checkpoint-loading code inside it is never executed.
# NOTE(review): `EncoderProcessorDecoder` is not imported in any visible cell,
# and the snippet calls `EncoderDecoder.from_encoder_decoder(...)` whereas the
# model cell constructs `EncoderDecoder(encoder=..., decoder=...)` directly -
# confirm both against the current autocast API before re-enabling.
'''
# Load WITH the components
model = EncoderProcessorDecoder.load_from_checkpoint(
    "lightning_logs/version_3/checkpoints/epoch=19-step=18340.ckpt",
    encoder_decoder=EncoderDecoder.from_encoder_decoder(
        encoder=encoder, decoder=decoder
    ),
    processor=processor,
    strict=False
)
'''


In [None]:
# Train with explicit train/validation loaders from the data module.
trainer.fit(
    model,
    train_dataloaders=datamodule.train_dataloader(),
    val_dataloaders=datamodule.val_dataloader(),
)

In [None]:
# trainer.test(model, datamodule.test_dataloader())

In [None]:
# Qualitative check: roll the trained model out on one batch from the rollout
# test loader.
batch = next(iter(datamodule.rollout_test_dataloader()))
model.stride = 4
# Shape of the input window (the first n_steps_input steps).
print(batch.input_fields.shape)
# Shape of the target window (the remaining steps).
print(batch.output_fields.shape)
# This is inference only: put the module in eval mode and disable autograd so
# the rollout neither uses training-time layer behaviour nor records a graph
# for every step (which would also leave the results attached to autograd).
model.eval()
with torch.no_grad():
    preds, trues = model.rollout(
        batch, stride=model.stride, free_running_only=True
    )

In [None]:
from IPython.display import HTML

from autocast.utils import plot_spatiotemporal_video

# Build the prediction-vs-truth animation from the rollout results.
anim = plot_spatiotemporal_video(true=trues, pred=preds, cmap="plasma")

# Convert to JS/HTML and leave it as the cell's last expression so the
# notebook renders it inline.
jshtml = anim.to_jshtml()
HTML(jshtml)