# BOUT - Diffusion test

In [None]:
import torch
import torch.nn as nn
from auto_cast.data.dataset import BOUTDataset
from auto_cast.data.datamodule import SpatioTemporalDataModule
from azula.noise import CosineSchedule
from auto_cast.types import EncodedBatch
# Root directory of the pre-split BOUT dataset (train/val/test).
data_path = "data/bout_split"


# ============================================================================
# 2. Load Data
# ============================================================================


In [None]:

# Window geometry: number of input steps fed to the model and number of
# output steps it must predict.
n_steps_input = 1
n_steps_output = 4

datamodule = SpatioTemporalDataModule(
    data_path=data_path,
    dataset_cls=BOUTDataset,
    verbose=True,
    n_steps_input=n_steps_input,
    n_steps_output=n_steps_output,
    stride=1,
    batch_size=1,
    ftype="torch",
    dtype=torch.float32,
)

# Pull one training batch to sanity-check tensor shapes.
train_loader = datamodule.train_dataloader()
batch = next(iter(train_loader))

(
    batch.input_fields.shape,
    batch.output_fields.shape,
    batch.constant_scalars.shape,
)


In [None]:
import torch
from einops import rearrange

from auto_cast.encoders.base import Encoder
from auto_cast.types import Batch, Tensor, TensorBCWH


class IdentityEncoder(Encoder):
    """Pass-through encoder: the latent is ``batch.input_fields`` itself.

    No permutation, concatenation or learned transform is applied, so the
    processor operates directly on the input fields in the layout the
    dataset provides them.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, batch: Batch) -> Tensor:
        """Return ``batch.input_fields`` unchanged."""
        return batch.input_fields

    def encode(self, batch: Batch) -> Tensor:
        """Encode ``batch`` by delegating to ``forward`` (identity)."""
        return self.forward(batch)
    
from einops import rearrange

from auto_cast.decoders.base import Decoder
from auto_cast.types import TensorBCTSPlus, TensorBMStarL, TensorBTSPlusC


class IdentityDecoder(Decoder):
    """Pass-through decoder: returns the processor output unchanged.

    Input and output share the same layout; no channel permutation is
    performed.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: TensorBTSPlusC) -> TensorBTSPlusC:
        """Return ``x`` unchanged."""
        return x

    def decode(self, z: TensorBTSPlusC) -> TensorBTSPlusC:
        """Decode ``z`` by delegating to ``forward`` (identity)."""
        return self.forward(z)


# Wrap the Azula UNet as a diffusion backbone

In [None]:
from matplotlib.pylab import cond
import torch
import torch.nn as nn
from azula.nn.unet import UNet
from azula.nn.embedding import SineEncoding

class TemporalUNetBackbone(nn.Module):
    """Azula UNet denoiser backbone with a learned diffusion-time embedding.

    All tensors are channels-last ``(B, T, W, H, C)``. Time frames and
    channels are folded into the UNet's channel axis, so the UNet processes a
    2-D image whose channel count is the flattened ``T * C``.

    Args:
        in_channels: Flattened width of the conditioning stack, i.e.
            ``T_cond * (C_cond + mod_features)``, because every conditioning
            frame is concatenated with a copy of the time embedding.
        out_channels: Flattened width of the noisy/denoised target,
            i.e. ``T * C`` of ``x_out``. The noisy target is also fed to the
            UNet, so this many channels are added to its input width.
        cond_channels: Additional input channels added to the UNet input
            width (0 to disable).
        mod_features: Width of the time embedding, also used as the UNet's
            modulation vector size.
        hid_channels: Hidden channels per UNet level.
        hid_blocks: Number of blocks per UNet level.
        spatial: Number of spatial dimensions.
        periodic: Whether the UNet uses periodic padding.
    """

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        cond_channels: int = 1,
        mod_features: int = 256,
        hid_channels: tuple = (32, 64, 128),
        hid_blocks: tuple = (2, 2, 2),
        spatial: int = 2,
        periodic: bool = False,
    ):
        super().__init__()

        # Sinusoidal features of the diffusion time followed by a small MLP.
        self.time_embedding = nn.Sequential(
            SineEncoding(mod_features),
            nn.Linear(mod_features, mod_features),
            nn.SiLU(),
            nn.Linear(mod_features, mod_features),
        )

        self.unet = UNet(
            # Conditioning(+time) channels plus the flattened noisy target,
            # which the denoiser must see in order to denoise it.
            in_channels=in_channels + cond_channels + out_channels,
            out_channels=out_channels,
            cond_channels=0,
            mod_features=mod_features,
            hid_channels=hid_channels,
            hid_blocks=hid_blocks,
            kernel_size=3,
            stride=2,
            spatial=spatial,
            periodic=periodic,
        )

    def forward(self, x_out: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Predict the denoised target from noisy frames, time and conditioning.

        Args:
            x_out: Noisy target frames, shape (B, T, W, H, C), channels last.
            t: Diffusion time per sample, shape (B,).
            cond: Conditioning frames, shape (B, T_cond, W, H, C_cond).

        Returns:
            Tensor with the same shape as ``x_out``: (B, T, W, H, C).

        Raises:
            ValueError: If ``x_out`` and ``cond`` have different spatial sizes.
        """
        B, T, W, H, C = x_out.shape
        _, T_cond, W_cond, H_cond, _ = cond.shape
        if (W, H) != (W_cond, H_cond):
            raise ValueError(
                f"spatial size mismatch: x_out has {(W, H)}, cond has {(W_cond, H_cond)}"
            )

        # Embed diffusion time once per sample: (B, mod_features).
        t_emb = self.time_embedding(t)

        # Broadcast the embedding over every conditioning frame and pixel and
        # append it to the conditioning channels: (B, T_cond, W, H, C_cond + m).
        t_map = rearrange(t_emb, "b m -> b 1 1 1 m").expand(B, T_cond, W, H, -1)
        x_cond = torch.cat([cond, t_map], dim=-1)

        # Fold time into channels (channels-first for the UNet) and append the
        # flattened noisy target.
        x_cond = rearrange(x_cond, "b t w h c -> b (t c) w h")
        x_noisy = rearrange(x_out, "b t w h c -> b (t c) w h")
        x_in = torch.cat([x_cond, x_noisy], dim=1)

        out = self.unet(x_in, mod=t_emb)  # (B, T * C, W, H)

        # Undo the (t c) folding. A plain reshape would interleave channel and
        # spatial axes because `out` is channels-first.
        return rearrange(out, "b (t c) w h -> b t w h c", t=T, c=C)


# ============================================================================
# 3. Create DiffusionProcessor
# ============================================================================

In [None]:
# Target window shape the processor must predict (Tensor.size() == .shape).
batch.output_fields.size()

In [None]:
from auto_cast.decoders.channels_last import ChannelsLast
from auto_cast.encoders.permute_concat import PermuteConcat
from auto_cast.models.encoder_decoder import EncoderDecoder
from auto_cast.models.encoder_processor_decoder import EncoderProcessorDecoder
from auto_cast.processors.diffusion import DiffusionProcessor
from azula.noise import CosineSchedule

# Read the number of physical channels from the last axis of a fresh batch.
batch = next(iter(datamodule.train_dataloader()))
n_channels = batch.input_fields.shape[-1]

# Diffusion noise schedule and width of the time embedding.
schedule = CosineSchedule()
mod_features = 64

# Flattened widths passed to the backbone: each input step carries its
# channels plus a copy of the time embedding; the output covers all
# predicted steps.
cond_width = n_steps_input * (n_channels + mod_features)
target_width = n_steps_output * n_channels

backbone = TemporalUNetBackbone(
    in_channels=cond_width,
    out_channels=target_width,
    cond_channels=0,
    mod_features=mod_features,
    hid_channels=(16, 32, 64),
    hid_blocks=(2, 2, 2),
    spatial=2,
    periodic=False,
)

processor = DiffusionProcessor(
    backbone=backbone,
    schedule=schedule,
    denoiser_type='karras',
    learning_rate=1e-4,
    n_steps_output=n_steps_output,
    stride=1,
    max_rollout_steps=10,
    teacher_forcing_ratio=0.0,
)

# Identity encoder/decoder: the processor works directly on raw fields.
encoder = IdentityEncoder()
decoder = IdentityDecoder()
encoder_decoder = EncoderDecoder.from_encoder_decoder(encoder=encoder, decoder=decoder)

model = EncoderProcessorDecoder.from_encoder_processor_decoder(
    encoder_decoder=encoder_decoder,
    processor=processor,
)

# ============================================================================
# 4. Train, Test and Roll Out
# ============================================================================

In [None]:
import lightning as L

# Prefer CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("Using device:", device)
# device = "cpu"  # uncomment to force CPU

# fp16 mixed precision ("16-mixed") is a CUDA feature: Lightning does not
# support it on CPU and it is not effective on MPS, so use full precision
# on other devices.
precision = "16-mixed" if device == "cuda" else "32-true"
trainer = L.Trainer(
    max_epochs=5,
    accelerator=device,
    log_every_n_steps=10,
    precision=precision,
)


In [None]:
# Train on the train split, validating on the val split each epoch.
trainer.fit(
    model,
    train_dataloaders=datamodule.train_dataloader(),
    val_dataloaders=datamodule.val_dataloader(),
)

In [None]:
# Optionally restore a trained model. Set CHECKPOINT_PATH to a .ckpt file to
# enable; with None this cell does nothing and `model` is left as is.
# Example: "lightning_logs/version_3/checkpoints/epoch=19-step=18340.ckpt"
CHECKPOINT_PATH = None

if CHECKPOINT_PATH is not None:
    # Extra keyword arguments are forwarded by Lightning to the model
    # constructor, so the sub-modules are supplied explicitly here.
    # strict=False tolerates state-dict keys that do not match exactly.
    model = EncoderProcessorDecoder.load_from_checkpoint(
        CHECKPOINT_PATH,
        encoder_decoder=EncoderDecoder.from_encoder_decoder(
            encoder=encoder, decoder=decoder
        ),
        processor=processor,
        strict=False,
    )


In [None]:
# Evaluate on the held-out test split; the returned metrics are displayed.
trainer.test(model, dataloaders=datamodule.test_dataloader())

In [None]:
# One batch from the rollout test loader.
batch = next(iter(datamodule.rollout_test_dataloader()))

# Input (conditioning) frames first, then the target frames that follow.
for fields in (batch.input_fields, batch.output_fields):
    print(fields.shape)

# free_running_only=True: presumably rolls out on the model's own predictions
# only (no teacher forcing) — confirm against model.rollout.
preds, trues = model.rollout(batch, free_running_only=True)

In [None]:
# ============================================================================
# Create Side-by-Side Animated GIF from Rollout
# ============================================================================

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.animation import PillowWriter

# preds / trues come from the model.rollout(...) call above.
# NOTE(review): assumed layout is [B, R, T, H, W, C] (R rollout windows,
# T steps per window) — confirm against the rollout implementation.
print(f"Predictions shape: {preds.shape}")  # [B, R, T, H, W, C]
print(f"Ground truth shape: {trues.shape if trues is not None else 'None'}")

# Which batch sample and which physical channel to animate.
sample_idx = 0
channel_idx = 0
gif_path = 'rollout_comparison.gif'

# Move the selected sample to NumPy.
pred_rollout = preds[sample_idx].cpu().detach().numpy()  # [R, T, H, W, C]
if trues is not None:
    true_rollout = trues[sample_idx].cpu().detach().numpy()  # [R, T, H, W, C]
else:
    # No ground truth: show a blank panel with matching shape.
    true_rollout = np.zeros_like(pred_rollout)

# Get dimensions
num_rollout_windows, timesteps_per_window, H, W, C = pred_rollout.shape
print(f"\nRollout windows: {num_rollout_windows}")
print(f"Timesteps per window: {timesteps_per_window}")
print(f"Spatial size: {H}x{W}")
print(f"Channels: {C}")

# Build the frame sequence from the first timestep of each rollout window.
pred_sequence = pred_rollout[:, 0, ...]  # [R, H, W, C]
true_sequence = true_rollout[:, 0, ...]  # [R, H, W, C]

# Prepend the last input frame so the animation starts from the initial state.
initial_frame = batch.input_fields[sample_idx, -1].cpu().numpy()  # [H, W, C]
pred_sequence = np.concatenate([initial_frame[None, ...], pred_sequence], axis=0)
true_sequence = np.concatenate([initial_frame[None, ...], true_sequence], axis=0)

# imshow needs 2-D frames: select one channel. For C == 1 this is the same as
# squeezing; for C > 1 it avoids multi-channel data being read as RGB.
pred_sequence = pred_sequence[..., channel_idx]  # [Frames, H, W]
true_sequence = true_sequence[..., channel_idx]

num_frames = pred_sequence.shape[0]
print(f"Total frames in sequence: {num_frames} (channel {channel_idx})")

# Shared colour range across both panels.
vmin = min(pred_sequence.min(), true_sequence.min())
vmax = max(pred_sequence.max(), true_sequence.max())
print(f"Data range: [{vmin:.4f}, {vmax:.4f}]")

def normalize(x):
    """Map x into [0, 1] using the shared vmin/vmax (no-op if the range is empty)."""
    return (x - vmin) / (vmax - vmin) if vmax > vmin else x

# Side-by-side panels: prediction (left) and ground truth (right).
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

im_pred = axes[0].imshow(normalize(pred_sequence[0]), cmap='viridis', vmin=0, vmax=1)
axes[0].set_title(f'Prediction (Frame 0/{num_frames-1})')
axes[0].axis('off')
plt.colorbar(im_pred, ax=axes[0], fraction=0.046, pad=0.04)

im_true = axes[1].imshow(normalize(true_sequence[0]), cmap='viridis', vmin=0, vmax=1)
axes[1].set_title(f'Ground Truth (Frame 0/{num_frames-1})')
axes[1].axis('off')
plt.colorbar(im_true, ax=axes[1], fraction=0.046, pad=0.04)

plt.tight_layout()

def update(frame_idx):
    """Draw frame `frame_idx` in both panels and refresh the titles."""
    im_pred.set_array(normalize(pred_sequence[frame_idx]))
    axes[0].set_title(f'Prediction (Frame {frame_idx}/{num_frames-1})')

    im_true.set_array(normalize(true_sequence[frame_idx]))
    axes[1].set_title(f'Ground Truth (Frame {frame_idx}/{num_frames-1})')

    return [im_pred, im_true]

print("\nCreating animation...")
anim = animation.FuncAnimation(
    fig,
    update,
    frames=num_frames,
    interval=100,  # 100 ms per frame = 10 fps
    # Titles change every frame but are not returned artists, so blitting
    # would leave them stale in live displays; redraw the full figure.
    blit=False,
    repeat=True
)

print("Saving GIF...")
writer = PillowWriter(fps=10)
anim.save(gif_path, writer=writer)
print(f"✓ Saved to: {gif_path}")

# Close the figure so the notebook does not also render a static copy.
plt.close(fig)

# Display in notebook
from IPython.display import Image, display
display(Image(filename=gif_path))
