# Sample Autoencoder

This notebook presents a sample autoencoder model. It showcases the functionality provided by the source utility modules, on which all model implementations can be built.

## Add project root to path

In [None]:
import sys
from pathlib import Path

# The project root is taken to be the parent of the current working
# directory (the notebook is expected to run from the notebooks folder).
src_root = Path.cwd().parent  # notebooks -> ml-sandbox

# Register the root on sys.path so that `src` can be imported; the membership
# test keeps repeated runs of this cell from appending duplicate entries.
src_root_entry = str(src_root)
if src_root_entry not in sys.path:
    sys.path.append(src_root_entry)

## Import necessary modules

In [None]:
from src import (
    MnistDataset,
    Regressor,
    Autoencoder,
)

# Machine learning
import torch
from torch import nn as nn

## Define training hyperparameters

In [None]:
# Single place for every tunable value used by the cells below.
hparams = {
    # Training hyperparameters
    "batch_size": 64,
    "num_epochs": 10,
    # Model hyperparameters
    "learning_rate": 1e-3,
    "regularization_weight": 0.0,
}

## Define the data loaders, model, optimizer, scheduler, and trainer

In [None]:
# Load MNIST and request data loaders for a 0.6 / 0.2 / 0.2
# train / validation / test split, batched per `hparams["batch_size"]`.
mnist = MnistDataset()
dataloaders = mnist.get_dataloaders(
    batch_size=hparams["batch_size"],
    train_split=0.6,
    val_split=0.2,
    test_split=0.2,
)

# Layer widths start at 28 * 28 (the MNIST pixel count) and narrow to 32;
# LeakyReLU is passed as the activation class, not as an instance.
# NOTE(review): presumably the decoder mirrors these widths -- confirm in
# src.Autoencoder.
model = Autoencoder(sigma=nn.LeakyReLU, dims=[28 * 28, 128, 64, 32])

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=hparams["learning_rate"],
)

# Halve the learning rate once the monitored quantity has stopped decreasing.
# NOTE(review): patience (10) is not smaller than hparams["num_epochs"] (10);
# if the trainer steps the scheduler once per epoch, no reduction can happen
# within this run -- confirm this is intended.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=10
)

# Build the regression trainer first, then configure L2 regularization on it.
# `trainer` ends up bound to whatever `initialize_regularization` returns.
trainer = Regressor(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    scheduler_metric=True,
    loss_function="MSE",
)
trainer = trainer.initialize_regularization(
    method="L2",
    weight=hparams["regularization_weight"],
)

## Train the model

In [None]:
# Run the training loop for the configured number of epochs, using the
# "train" loader for fitting and the "val" loader for validation.
results = trainer.train(
    epochs=hparams["num_epochs"],
    validation_dataloader=dataloaders["val"],
    train_dataloader=dataloaders["train"],
)

## Print training results

In [None]:
# Banner: the title framed by 20 '=' characters on each side.
separator = "=" * 20
print(separator + " Training Results " + separator)

# Both losses are reported with four decimal places.
print(f"Final training loss: {results.train_loss:.4f}")
print(f"Final validation loss: {results.validation_loss:.4f}")

## Show sample reconstructed images

In [None]:
import random
from src.input_data.structure.base import ManagedDataset
from src.input_data.structure.plots import plot_samples


def show_random_samples(
    dataset: ManagedDataset, model: nn.Module, num_samples: int = 8, figsize=(12, 8)
) -> None:
    """Plot randomly chosen dataset images next to their reconstructions.

    Up to ``num_samples`` distinct items are drawn from ``dataset``. Each
    image is passed through ``model`` with gradient tracking disabled, and the
    original together with the model output is handed to ``plot_samples`` as
    an "Original" / "Reconstructed" pair (one ``plot_samples`` call per
    sampled item).

    Side effects: ``model`` is moved to CUDA when available (CPU otherwise)
    and is left in evaluation mode when the function returns.

    Args:
        dataset: Dataset to sample from. It must support ``len()`` and integer
            indexing that yields a two-element item whose first element is the
            image tensor, and expose ``dataset_info.name`` for the plot title.
        model: Network used to reconstruct each image.
        num_samples: Maximum number of items to show; capped at
            ``len(dataset)``.
        figsize: Currently unused; kept so that existing callers passing it
            keep working.
    """
    if len(dataset) == 0:
        print("Dataset is empty!")
        return

    # Set model to evaluation mode and move to a proper device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # Select random indices (without repetition, never more than available)
    indices = random.sample(range(len(dataset)), min(num_samples, len(dataset)))

    with torch.no_grad():
        for idx in indices:
            # Only the image is needed; the second element is ignored.
            image, _ = dataset[idx]

            image = image.to(device)

            # Add a leading batch axis for the forward pass, drop it afterwards.
            image_reconstructed = model(image.unsqueeze(0)).squeeze(0)

            # Pass CPU tensors to the plotting helper: tensors living on an
            # accelerator cannot be converted to NumPy directly, and `.cpu()`
            # is a no-op when the tensor is already on the CPU.
            image_list = [image.cpu(), image_reconstructed.cpu()]
            label_list = ["Original", "Reconstructed"]

            plot_samples(
                image_list,
                labels=label_list,
                suptitle=f"Random samples of the {dataset.dataset_info.name} dataset",
            )


# Visualize five random MNIST items together with their reconstructions.
show_random_samples(dataset=mnist, model=model, num_samples=5)