In [None]:
from pathlib import Path

from hydra import compose, initialize_config_dir
from omegaconf import OmegaConf

# Hydra config directory: `configs/` one level above the current working
# directory, resolved to an absolute path before being handed to Hydra.
CONFIG_DIR = Path.cwd().joinpath("..", "configs").resolve()

# Hydra overrides applied on top of the "autoencoder" config.
# Uncomment any of the entries below for a quick test run:
overrides = [
    # "trainer.max_epochs=1",
    # "trainer.accelerator=cpu",
    # "data.split.n_train=2",
    # "data.split.n_valid=1",
    # "data.split.n_test=1",
]

# Compose the config inside a temporary Hydra initialisation scope.
with initialize_config_dir(config_dir=str(CONFIG_DIR), version_base=None):
    cfg = compose(overrides=overrides, config_name="autoencoder")

# Dump the composed config as YAML so the run settings are visible in the cell
# output.
cfg_yaml = OmegaConf.to_yaml(cfg)
print(cfg_yaml)

In [None]:
from autocast.logging import create_wandb_logger, maybe_watch_model

# Switch on Weights & Biases logging for this notebook and label the run.
run_name = "01_encoder_decoder"
wandb_cfg = cfg.logging.wandb
wandb_cfg.enabled = True
wandb_cfg.project = "autocast-notebooks"
wandb_cfg.name = run_name

# Snapshot the fully resolved Hydra config as plain Python containers so it can
# be attached to the run under the "hydra" key.
resolved_cfg = OmegaConf.to_container(cfg, resolve=True)
if resolved_cfg is None:
    wandb_run_config = None
else:
    wandb_run_config = {"hydra": resolved_cfg}

wandb_logger, wandb_watch = create_wandb_logger(
    cfg.logging,
    experiment_name=run_name,
    job_type="notebook",
    config=wandb_run_config,
)

In [None]:
from autocast.train.autoencoder import build_datamodule

# Build the data module from the `data` section of the composed config.
data_cfg = cfg.data
datamodule = build_datamodule(data_cfg)

In [None]:

from autocast.types.batch import Batch

# Pull one batch from the training loader and show the shapes of its input and
# output fields as the cell output.
train_loader = datamodule.train_dataloader()
batch: Batch = next(iter(train_loader))
train_inputs, train_outputs = batch.input_fields, batch.output_fields
(train_inputs.shape, train_outputs.shape)


In [None]:
import torch

# Check whether the sampled inputs and outputs agree element-wise (within
# torch.allclose's default tolerances); the boolean is the cell output.
inputs_match_outputs = torch.allclose(train_inputs, train_outputs)
inputs_match_outputs

In [None]:
# Read the size of the last axis of the sampled inputs and write it into both
# ends of the model config (encoder first, then decoder).
channel_count = train_inputs.shape[-1]
model_cfg = cfg.model
model_cfg.encoder.in_channels = model_cfg.decoder.out_channels = channel_count
print(
    f"Detected {channel_count} channels; config updated to match input distribution."
)

In [None]:
from hydra.utils import instantiate

from autocast.train.autoencoder import build_model

# Build the model from the (channel-adjusted) config and hand it to the W&B
# watch helper together with the logger created earlier.
model = build_model(cfg.model)
maybe_watch_model(wandb_logger, model, wandb_watch)

# Keyword overrides layered on top of the `trainer` config at instantiation:
# log to W&B, skip automatic checkpointing, and use the current directory as
# the root dir.
trainer_kwargs = {
    "logger": wandb_logger,
    "enable_checkpointing": False,
    "default_root_dir": ".",
}
trainer = instantiate(cfg.trainer, **trainer_kwargs)

# Display the model as the cell output.
model

In [None]:
# Fit the model using the data module's training and validation loaders.
fit_train_loader = datamodule.train_dataloader()
fit_val_loader = datamodule.val_dataloader()
trainer.fit(
    model,
    train_dataloaders=fit_train_loader,
    val_dataloaders=fit_val_loader,
)

In [None]:
from pathlib import Path

# Save the trainer state to a checkpoint file (relative to the working
# directory) and show its absolute location as the cell output.
checkpoint_name = "notebook_autoencoder.ckpt"
checkpoint_path = Path(checkpoint_name)
trainer.save_checkpoint(checkpoint_path)
checkpoint_abspath = checkpoint_path.resolve()
checkpoint_abspath

In [None]:
import matplotlib.pyplot as plt
import torch

device = "cpu"
num_examples = 2


def _first_frame(field):
    """Return ``field[0, 0, :, :, 0]`` as a NumPy array suitable for ``imshow``.

    The last axis is the channel axis in this notebook (the channel count is
    read from ``shape[-1]`` earlier); the meaning of the two leading axes is
    presumably (batch, time) - TODO confirm against the dataset.
    """
    return field[0, 0, :, :, 0].detach().cpu().numpy()


# Switch to evaluation mode so layers that behave differently during training
# (e.g. dropout, batch norm) do not distort the reconstructions shown below.
model.eval()

for idx, batch in enumerate(datamodule.test_dataloader()):
    inputs = batch.input_fields.to(device)
    # NOTE(review): only `inputs` is moved to `device`; the model still receives
    # the original `batch`. Confirm the model and batch live on the same device.
    # Gradients are not needed for visualisation, so skip autograd bookkeeping.
    with torch.no_grad():
        outputs, latents = model.forward_with_latent(batch)
    print("Input shape:", inputs.shape)
    print("Output shape:", outputs.shape)
    print("Latent shape:", latents.shape)

    input_frame = _first_frame(inputs)
    output_frame = _first_frame(outputs)
    latent_frame = _first_frame(latents)

    # One row of four panels: input, reconstruction, their difference, latent.
    panels = (
        ("Input", input_frame),
        ("Reconstruction", output_frame),
        ("Difference", output_frame - input_frame),
        ("Latent", latent_frame),
    )
    fig, axs = plt.subplots(1, 4, figsize=(10, 4))
    for ax, (title, frame) in zip(axs, panels):
        ax.imshow(frame, cmap="viridis")
        ax.set_title(title)
    plt.tight_layout()
    plt.show()
    # Release the figure explicitly so repeated iterations do not accumulate
    # open figures.
    plt.close(fig)

    # Stop after `num_examples` batches have been visualised.
    if idx + 1 >= num_examples:
        break