In [1]:

import warnings

import matplotlib.pyplot as plt
import torch

from models.autoencoder import KoopmanAutoencoder
from models.utils import load_checkpoint, load_config, get_dataset_class_and_kwargs, load_datasets
from models.dataloader import create_dataloaders
from models.metrics import Metric



# Configuration


In [2]:
CONFIG_PATH = "path/to/your/config.yaml"  # Update this path
# Path to trained weights. None means NO checkpoint is loaded: the model keeps
# its random initialisation, so the rollout/metrics below are not meaningful.
CKPT_PATH = None

ROLL_OUT_STEPS = 40  # Number of future time steps to predict
VISUALIZE = True     # Toggle to show visualizations
SEED = 0             # Seed for torch's RNG (random init, any stochastic layers)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed early so a Restart-&-Run-All reproduces the same numbers.
torch.manual_seed(SEED)

In [None]:
# Load the experiment configuration and turn off its `use_checkpoint` training
# option, which is not wanted at evaluation time.
config = load_config(CONFIG_PATH)
config["training"]["use_checkpoint"] = False

# Build only the test split; the train/val slots are intentionally discarded.
dataset_class, dataset_kwargs = get_dataset_class_and_kwargs(config)
_, _, test_dataset = load_datasets(config, dataset_class, dataset_kwargs)

# create_dataloaders is given no train/val datasets, so only the test loader is kept.
_, _, test_loader = create_dataloaders(
    config=config,
    test_dataset=test_dataset,
    val_dataset=None,
    train_dataset=None,
)

# Load Model


In [None]:
# Build the autoencoder from the config's model section; the number of input
# frames must match the data pipeline's context length.
model_cfg = config["model"]
model = KoopmanAutoencoder(
    input_frames=config["data"]["input_sequence_length"],
    input_channels=model_cfg["input_channels"],
    height=model_cfg["height"],
    width=model_cfg["width"],
    latent_dim=model_cfg["latent_dim"],
    hidden_dims=model_cfg["hidden_dims"],
    use_checkpoint=False,
    # Tolerate a missing or null `conv_kwargs` entry instead of raising.
    **(model_cfg.get("conv_kwargs") or {}),
).to(DEVICE)

if CKPT_PATH is not None:
    print(f"Loading from checkpoint: {CKPT_PATH}")
    model, _, _, _ = load_checkpoint(CKPT_PATH, model=model, optimizer=None)
else:
    # Nothing is loaded on this branch, so every plot and metric below would
    # describe randomly initialised weights -- make that impossible to miss.
    warnings.warn(
        "CKPT_PATH is None: no checkpoint loaded; evaluating an untrained "
        "(randomly initialised) model."
    )

model.eval()

# Autoregressive Rollout Helper


In [None]:
def run_long_rollout(model, input_seq, rollout_steps, context_len=None, device=None):
    """Autoregressively roll ``model`` forward from a seed sequence.

    At every step the most recent ``context_len`` frames (seed frames plus any
    predictions made so far) are stacked and passed to ``model.predict``. The
    returned frame is appended to the history, and this repeats for
    ``rollout_steps`` steps.

    Args:
        model: Object exposing ``predict(context)``. ``context`` has shape
            [1, <=context_len, C, H, W]. The result is appended as one frame and is
            expected to be [1, C, H, W] (per the original comment -- confirm
            against the model implementation).
        input_seq: Seed frames with shape [T, C, H, W] (no batch dimension).
        rollout_steps: Number of future frames to predict.
        context_len: Sliding-window length. Defaults to the notebook-level
            ``config["data"]["input_sequence_length"]``.
        device: Device to run on. Defaults to the notebook-level ``DEVICE``.

    Returns:
        CPU tensor of shape [T + rollout_steps, C, H, W]. The first T frames
        are the seed frames copied unchanged, followed by the predictions.

    Raises:
        ValueError: if ``context_len`` is smaller than 1 (``preds[-0:]`` would
            otherwise silently feed the entire history).
    """
    if context_len is None:
        context_len = config["data"]["input_sequence_length"]
    if device is None:
        device = DEVICE
    if context_len < 1:
        raise ValueError(f"context_len must be >= 1, got {context_len}")

    batched = input_seq.unsqueeze(0).to(device)  # [1, T, C, H, W]
    history = list(batched.unbind(dim=1))         # T tensors of [1, C, H, W]

    with torch.no_grad():
        for _ in range(rollout_steps):
            window = torch.stack(history[-context_len:], dim=1)  # sliding window
            history.append(model.predict(window))

    return torch.stack(history, dim=1).squeeze(0).cpu()  # [T + rollout_steps, C, H, W]

# Roll Out a Test Sample


In [None]:
context_len = config["data"]["input_sequence_length"]

# NOTE(review): assumes the dataset returns a bare tensor shaped [T, C, H, W];
# confirm against the dataset's __getitem__.
sample = test_dataset[0]
if sample.shape[0] < context_len:
    raise ValueError(
        f"Test sample has {sample.shape[0]} frames but the model needs "
        f"{context_len} context frames."
    )

input_seq = sample[:context_len]
ground_truth = sample[:context_len + ROLL_OUT_STEPS]

# Slicing silently truncates when the sample is too short, which would make the
# "long rollout" metric cover fewer steps than ROLL_OUT_STEPS -- say so.
available_steps = ground_truth.shape[0] - context_len
if available_steps < ROLL_OUT_STEPS:
    warnings.warn(
        f"Only {available_steps} of {ROLL_OUT_STEPS} rollout steps have ground "
        "truth; comparisons will cover the shorter horizon."
    )

predicted_seq = run_long_rollout(model, input_seq, ROLL_OUT_STEPS)

# Plot Results


In [None]:
def plot_rollout(gt, pred, variable_idx=0, frame_stride=5, max_frames=10):
    """Show ground-truth (top row) vs. predicted (bottom row) frames for one variable.

    Columns are taken at time indices 0, frame_stride, 2*frame_stride, ...
    (at most ``max_frames`` of them), limited to frames present in both
    ``gt`` and ``pred``. Striding matters because a rollout from
    ``run_long_rollout`` begins with the copied input context. Plotting only
    the leading consecutive frames could therefore show no predictions at all.

    Args:
        gt: Ground-truth frames indexed as [T, C, H, W] (CPU tensor or array,
            since matplotlib converts it to NumPy).
        pred: Predicted frames with the same layout as ``gt``.
        variable_idx: Channel index to display.
        frame_stride: Spacing between plotted time indices (must be >= 1).
        max_frames: Maximum number of columns to draw.

    Raises:
        ValueError: if ``frame_stride`` is smaller than 1.
    """
    if frame_stride < 1:
        raise ValueError(f"frame_stride must be >= 1, got {frame_stride}")

    n_common = min(gt.shape[0], pred.shape[0])
    times = list(range(0, n_common, frame_stride))[:max_frames]
    if not times:
        return  # nothing to draw

    gt_frames = [gt[t, variable_idx] for t in times]
    pred_frames = [pred[t, variable_idx] for t in times]
    # Use one colour scale (from the ground truth) for every panel so that a
    # given colour means the same value in both rows.
    vmin = min(float(frame.min()) for frame in gt_frames)
    vmax = max(float(frame.max()) for frame in gt_frames)

    fig, axes = plt.subplots(2, len(times), figsize=(15, 6), squeeze=False)
    rows = (("Ground Truth", gt_frames), ("Prediction", pred_frames))
    for row, (label, frames) in enumerate(rows):
        for col, (t, frame) in enumerate(zip(times, frames)):
            ax = axes[row, col]
            ax.imshow(frame, cmap="viridis", vmin=vmin, vmax=vmax)
            ax.set_axis_off()
            ax.set_title(f"{label}\nt={t}" if col == 0 else f"t={t}")

    fig.suptitle("Koopman AE: Ground Truth vs Prediction (Long Rollout)")
    fig.tight_layout()
    plt.show()

# Draw the GT-vs-prediction comparison only when enabled in the config cell.
if VISUALIZE:
    plot_rollout(gt=ground_truth, pred=predicted_seq)


# Metrics

In [None]:
metric = Metric(
    mode=config["metric"]["type"],
    variable_mode=config["metric"]["variable_mode"],
)

# The first `context_len` frames of `predicted_seq` are the seed frames copied
# verbatim from the ground truth. Scoring them would deflate the rollout error,
# so only the actually predicted steps are compared.
context_len = config["data"]["input_sequence_length"]
horizon = min(predicted_seq.shape[0], ground_truth.shape[0])
if horizon <= context_len:
    raise ValueError(
        f"No ground-truth frames beyond the {context_len}-frame context; nothing to score."
    )
loss = metric(predicted_seq[context_len:horizon], ground_truth[context_len:horizon])

# NOTE(review): depending on `variable_mode` the metric may return one value per
# variable; handle both a scalar and a vector instead of crashing on `:.4f`.
loss_values = torch.as_tensor(loss).detach().cpu().flatten()
if loss_values.numel() == 1:
    print(f"\nLong Rollout {config['metric']['type']} Metric: {loss_values.item():.4f}")
else:
    formatted = ", ".join(f"{v:.4f}" for v in loss_values.tolist())
    print(f"\nLong Rollout {config['metric']['type']} Metric: [{formatted}]")