<a href="https://colab.research.google.com/github/vlad-uve/CAE-MNIST/blob/main/notebooks/CAE_setup.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Create project folder structure
!mkdir -p ../CAE_setup_local/src
!mkdir -p ../CAE_setup_local/notebooks
!mkdir -p ../CAE_setup_local/outputs

# Check structure (install tree if needed)
!apt-get install tree -y
!tree ../CAE_setup_local

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following NEW packages will be installed:
  tree
0 upgraded, 1 newly installed, 0 to remove and 34 not upgraded.
Need to get 47.9 kB of archives.
After this operation, 116 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 tree amd64 2.0.2-1 [47.9 kB]
Fetched 47.9 kB in 0s (116 kB/s)
Selecting previously unselected package tree.
(Reading database ... 126102 files and directories currently installed.)
Preparing to unpack .../tree_2.0.2-1_amd64.deb ...
Unpacking tree (2.0.2-1) ...
Setting up tree (2.0.2-1) ...
Processing triggers for man-db (2.10.2-1) ...
[01;34m../CAE_setup_local[0m
├── [01;34mnotebooks[0m
├── [01;34moutputs[0m
└── [01;34msrc[0m

3 directories, 0 files


# Model Classes Setup

In [2]:
# write encoder, decoder and autoencoder classes to src as model.py
%%writefile ../CAE_setup_local/src/model.py

import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self, n_channels, stride, padding, latent_dim, use_batch_norm, activation_func):
        super(Encoder, self).__init__()

        # store activation function choice
        self.activation_func = activation_func

        # define convolutional layers for encoding
        self.encn1 = nn.Conv2d(1, n_channels[0], 4, stride=stride, padding=padding)
        self.encn2 = nn.Conv2d(n_channels[0], n_channels[1], 4, stride=stride, padding=padding)
        self.encn3 = nn.Conv2d(n_channels[1], n_channels[2], 3, stride=stride, padding=padding)

        # optional batch normalization layers after each conv
        if use_batch_norm:
            self.bn1 = nn.BatchNorm2d(n_channels[0])
            self.bn2 = nn.BatchNorm2d(n_channels[1])
            self.bn3 = nn.BatchNorm2d(n_channels[2])
        else:
            self.bn1 = self.bn2 = self.bn3 = nn.Identity()

        # flatten and fully connected bottleneck layer
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(n_channels[2] * 4 * 4, latent_dim)

    def apply_activation(self, x):
        # safely apply activation function
        if self.activation_func == 'relu':
            return F.relu(x)
        elif self.activation_func == 'leaky_relu':
            return F.leaky_relu(x)
        else:
            raise ValueError(f"Unsupported activation function: {self.activation_func}")

    def forward(self, input_x):
        # pass through encoding layers with batchnorm and activation
        x = self.encn1(input_x)
        x = self.bn1(x)
        x = self.apply_activation(x)

        x = self.encn2(x)
        x = self.bn2(x)
        x = self.apply_activation(x)

        x = self.encn3(x)
        x = self.bn3(x)
        x = self.apply_activation(x)

        # return data encoded in latent space with flattening and fully connected layer
        x = self.flatten(x)
        encoded_x = self.fc1(x)

        return encoded_x


class Decoder(nn.Module):
    def __init__(self, n_channels, stride, padding, latent_dim, use_batch_norm, activation_func):
        super(Decoder, self).__init__()

        # store activation function choice
        self.activation_func = activation_func

        # fully connected + unflatten to prepare for decoding
        self.fc1 = nn.Linear(latent_dim, n_channels[2] * 4 * 4)
        self.unflatten = nn.Unflatten(1, (n_channels[2], 4, 4))

        # optional batch normalization layers after each transposed conv
        if use_batch_norm:
            self.bn1 = nn.BatchNorm2d(n_channels[1])
            self.bn2 = nn.BatchNorm2d(n_channels[0])
        else:
            self.bn1 = self.bn2 = nn.Identity()

        # define transposed conv layers for decoding
        self.decn1 = nn.ConvTranspose2d(n_channels[2], n_channels[1], 3, stride=stride, padding=padding)
        self.decn2 = nn.ConvTranspose2d(n_channels[1], n_channels[0], 4, stride=stride, padding=padding)
        self.decn3 = nn.ConvTranspose2d(n_channels[0], 1, 4, stride=stride, padding=padding)

    def apply_activation(self, x):
        # safely apply activation function
        if self.activation_func == 'relu':
            return F.relu(x)
        elif self.activation_func == 'leaky_relu':
            return F.leaky_relu(x)
        else:
            raise ValueError(f"Unsupported activation function: {self.activation_func}")

    def forward(self, encoded):
        #decoding data from latent space with unflattening of fully connected layer
        x = self.fc1(encoded)
        x = self.unflatten(x)

        # pass through transposed conv layers with batchnorm and activation
        x = self.decn1(x)
        x = self.bn1(x)
        x = self.apply_activation(x)

        x = self.decn2(x)
        x = self.bn2(x)
        x = self.apply_activation(x)

        x = self.decn3(x)
        decoded_x = F.sigmoid(x)

        # return decoded data as an image with pixels of [0, 1] range
        return decoded_x


class AutoEncoder(nn.Module):
    def __init__(self, n_channels, latent_dim, use_batch_norm=False, activation_func='relu'):
        super(AutoEncoder, self).__init__()

        # fixed parameters for all conv layers
        stride = 2
        padding = 1

        # define encoder and decoder blocks
        self.encoder = Encoder(n_channels, stride, padding, latent_dim, use_batch_norm, activation_func)
        self.decoder = Decoder(n_channels, stride, padding, latent_dim, use_batch_norm, activation_func)

    def forward(self, input_x):
        # encode and decode the input data
        encoded_x = self.encoder(input_x)
        decoded_x = self.decoder(encoded_x)

        return decoded_x, encoded_x

Writing ../CAE_setup_local/src/model.py


# Model Training Functions Setup

In [3]:
# write model training, model validation, and full model traning run functions to src as train.py
%%writefile ../CAE_setup_local/src/train.py

import torch
import torch.nn.functional as F

def train_model(model, train_dataloader, optimizer, epoch, device):
    """
    Runs one training epoch for the given model.

    Args:
        model (nn.Module): the autoencoder to train
        train_dataloader (DataLoader): training data loader
        optimizer (torch.optim.Optimizer): optimizer used for training
        epoch (int): current epoch number (used for tracking/logging)

    Returns:
        float: loss value from the last batch of the epoch
    """

    # set model to training mode
    model.train()

    for b_i, (input_x, _) in enumerate(train_dataloader):
        # move batch to device
        input_x = input_x.to(device)

        # clear previous gradients
        optimizer.zero_grad()

        # forward pass: get model output
        decoded_x, encoded_x = model(input_x)

        # compute reconstruction loss between input and output
        loss = F.binary_cross_entropy(decoded_x, input_x)

        # backward pass: compute gradients
        loss.backward()

        # update weights
        optimizer.step()

    # return last batch loss
    return loss.item()


def validate_model(model, validation_dataloader, device):
    """
    Evaluates the model on the validation set using binary cross-entropy loss.

    Args:
        model (nn.Module): trained autoencoder
        validation_dataloader (DataLoader): validation data loader
        device (str): 'cuda' or 'cpu'

    Returns:
        float: average loss over the entire validation set
    """

    # set model to evaluation mode (disables dropout, batchnorm updates etc.)
    model.eval()
    total_loss = 0

    # disable gradient calculation
    with torch.no_grad():
        for input_x, _ in validation_dataloader:
            # move batch to device
            input_x = input_x.to(device)

            # forward pass
            decoded_x, encoded_x = model(input_x)

            # accumulate reconstruction loss for each bath
            total_loss += F.binary_cross_entropy(decoded_x, input_x)

    # compute and return average loss over validation over one epoch
    avg_loss = total_loss / len(validation_dataloader)

    return avg_loss.item()


def run_model_training(model, train_dataloader, validation_dataloader, optimizer, scheduler, num_epoch, device):
    """
    Trains the model across multiple epochs and evaluates on validation set.

    Args:
        model (nn.Module): the autoencoder
        train_dataloader (DataLoader): training data
        validation_dataloader (DataLoader): validation data
        optimizer (Optimizer): optimizer for training
        scheduler (LRScheduler): learning rate scheduler
        num_epoch (int): number of training epochs

    Returns:
        model: trained model
        dict: loss history containing 'train', 'validation', and 'epoch' lists
    """

    # initialize loss tracking dictionary
    loss_history = {'train': [], 'validation': [], 'epoch': []}

    print('\nTRAINING IS STARTED:')

    # run training loop
    for epoch in range(1, num_epoch + 1):
        # train model on training set
        train_loss = train_model(model, train_dataloader, optimizer, epoch, device)

        # evaluate model on validation set
        validation_loss = validate_model(model, validation_dataloader, device)

        # check if scheduler reduces learning rate based on validation loss plateau
        previous_lr = optimizer.param_groups[0]['lr']
        scheduler.step(validation_loss)
        current_lr = optimizer.param_groups[0]['lr']
        if current_lr != previous_lr:
            print(f"LR reduced from {previous_lr:.4f} → {current_lr:.4f}")

        # record losses and epoch number
        loss_history['train'].append(train_loss)
        loss_history['validation'].append(validation_loss)
        loss_history['epoch'].append(epoch)

        # print progress
        print(f"Epoch {epoch:2d} | Train Loss: {train_loss:.4f} | Validation Loss: {validation_loss:.4f}")

    print('\nTRAINING IS FINISHED.')

    return model, loss_history

Writing ../CAE_setup_local/src/train.py


# Data Loaders Setup

In [4]:
# write train and validation dataloader functions to src as data.py

%%writefile ../CAE_setup_local/src/data.py

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

def get_train_dataloader(batch_size=32, data_dir='../data'):
    """
    Loads the MNIST training set and returns a DataLoader.

    Args:
        batch_size (int): size of each training batch
        data_dir (str): path to MNIST data storage

    Returns:
        DataLoader: PyTorch DataLoader for training data
    """
    train_dataset = datasets.MNIST(
        root=data_dir,
        train=True,
        download=True,
        transform=transforms.ToTensor()
    )

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

    return train_dataloader


def get_validation_dataloader(batch_size=500, data_dir='../data'):
    """
    Loads the MNIST validation (test) set and returns a DataLoader.

    Args:
        batch_size (int): size of each validation batch
        data_dir (str): path to MNIST data storage

    Returns:
        DataLoader: PyTorch DataLoader for validation data
    """
    validation_dataset = datasets.MNIST(
        root=data_dir,
        train=False,
        transform=transforms.ToTensor()
    )

    validation_dataloader = DataLoader(
        validation_dataset,
        batch_size=batch_size,
        shuffle=False
    )

    return validation_dataloader

Writing ../CAE_setup_local/src/data.py


# Evaluation Functions

In [5]:
# write sampling and experiment reconstruction functions to src as evaluation.py

%%writefile ../CAE_setup_local/src/evaluation.py

import torch

def get_image_samples(validation_dataloader):

  # get a batch of images and labels from the validation set
  images, labels = next(iter(validation_dataloader))

  # select exactly one example for each digit (0-9)
  sample_labels = [i for i in range(10)]
  sample_indices = [torch.where(labels == i)[0][0].item() for i in range(10)]

  sample_images = images[sample_indices]

  return sample_images, sample_labels



def get_experiment_reconstructions(model_list, original_images, device):
    '''
    Run models on input images and return reconstructed outputs.

    Args:
        model_list (list): list of trained models to evaluate
        original_images (torch.Tensor): batch of original input images

    Returns:
        list of torch.Tensor: reconstructed images for each model
    '''
    reconstructions = []
    for model in model_list:
        model.eval()
        with torch.no_grad():
            reconstructed_images, _ = model(original_images.to(device))
            reconstructions.append(reconstructed_images.cpu())
    return reconstructions

Writing ../CAE_setup_local/src/evaluation.py


# Visualization Functions

In [6]:
# write visulization helper function to src as plotting.py
%%writefile ../CAE_setup_local/src/plotting.py

import matplotlib.pyplot as plt
import torch

def plot_baseline_history(baseline_loss, to_plot_train=False):
    '''
    Plot loss history for the baseline model.

    Args:
        baseline_loss (dict): dictionary with 'epoch', 'train', and 'validation' lists
        to_plot_train (bool): if True, also plot training loss
    '''

    color=plt.get_cmap('tab10').colors

    # optionally plot training losses
    if to_plot_train:
        plt.plot(baseline_loss['epoch'], baseline_loss['train'], label='Base model (training loss)', color=color[0], linestyle='--')

    # plot validation losses
    plt.plot(baseline_loss['epoch'], baseline_loss['validation'], label='Base model (validation loss)', color=color[0], linewidth=2)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Baseline Model Loss')
    plt.legend()


def plot_experiment_history(loss_list, label_list, title, to_plot_train=False):
    '''
    Plot loss curves for multiple models.

    Args:
        loss_list (list of dict): list of loss history dictionaries (per model)
        label_list (list of str): list of model names (same length as loss_list)
        title (str): title for the plot
        to_plot_train (bool): if True, also plot training loss curves

    Each dictionary in loss_list must contain:
        - 'epoch': list of epoch numbers
        - 'train': list of training losses (optional if to_plot_train=False)
        - 'validation': list of validation losses
    '''

    color=plt.get_cmap('tab10').colors

    # loop over each loss history in the list
    for i, (loss_history, label) in enumerate(zip(loss_list, label_list)):
        # optionally plot training losses
        if to_plot_train:
            plt.plot(loss_history['epoch'], loss_history['train'], label=label + ' (training loss)', color=color[i+1], linestyle='--')

        # plot validation losses
        plt.plot(loss_history['epoch'], loss_history['validation'], label=label + ' (validation loss)', color=color[i+1], linewidth=2)

    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()


def plot_digits_row(images, labels=None, title=None, cmap='magma', figsize=(15, 3)):
    '''
    Display a row of digit images side by side.

    Args:
        images (numpy array or torch tensor): array of images (n_images, height, width)
        labels (list or array, optional): optional list of labels to display as titles
        title (str, optional): overall title for the plot
        cmap (str): matplotlib colormap for image display
        figsize (tuple): figure size for the plot
    '''

    n_images = images.shape[0]

    fig, axes = plt.subplots(1, n_images, figsize=figsize)

    for idx, ax in enumerate(axes.flat):
        # display each image
        ax.imshow(images[idx], cmap=cmap)
        ax.axis('off')

        # optionally set image label as title
        if labels is not None:
            ax.set_title(str(labels[idx]), fontsize=20)

    # optionally set a main title for the plot
    if title is not None:
        plt.suptitle(title, y=1, fontsize=30)

    plt.tight_layout()
    plt.show()

    # separator
    print('\n ')


def plot_experiment_reconstructions(reconstructions, labels, title_list):
    '''
    Plot reconstruction results for multiple models.

    Args:
        reconstructions (list of torch.Tensor): reconstructed outputs from models
        labels (list or array): labels for each image
        title_list (list of str): titles to display for each model
    '''
    for recon, title in zip(reconstructions, title_list):
        plot_digits_row(
            recon.squeeze(),
            labels,
            title=title + ' reconstructed digits'
        )

Writing ../CAE_setup_local/src/plotting.py


# Exporting Functions

In [10]:
# write fil exporting fucntions to src as export.py
%%writefile ../CAE_setup_local/src/export.py

import os
import torch
import shutil
import subprocess

def save_experiment_files(
    experiment_name,
    models,
    losses,
    reconstructions,
    description_text,
    local_path_root='/content'
):
    """
    Save experiment files: model weights, loss history, reconstructions, and description.

    Args:
        experiment_name (str): e.g., "experiment_2"
        models (list): list of trained model objects
        losses (list): list of loss history objects
        reconstructions (list): list of reconstructed image tensors
        description_text (str): plain-text description
        local_path_root (str): where to create the export folder (default: '/content')
    """

    export_folder = os.path.join(local_path_root, f'CAE_{experiment_name}_local')
    os.makedirs(export_folder, exist_ok=True)

    for idx, (model, loss, recon) in enumerate(zip(models, losses, reconstructions)):
        torch.save(model.state_dict(), os.path.join(export_folder, f'{experiment_name}_model_{idx+1}.pth'))
        torch.save(loss, os.path.join(export_folder, f'{experiment_name}_loss_{idx+1}.pth'))
        torch.save(recon, os.path.join(export_folder, f'{experiment_name}_reconstruction_{idx+1}.pth'))

    with open(os.path.join(export_folder, f'{experiment_name}_description.txt'), 'w') as f:
        f.write(description_text.strip())

    print(f"✅ Saved {experiment_name} files to: {export_folder}")


def export_experiment_files(experiment_name, model_count,
                            local_root='/content',
                            repo_root='/content/CAE-MNIST'):
    """
    Copies experiment output files from local folder to Git repo and pushes them.

    Args:
        experiment_name (str): e.g. "experiment_2"
        model_count (int): number of models/files to export
        local_root (str): path where local files are stored (default: /content)
        repo_root (str): path to the cloned Git repo (default: /content/CAE-MNIST)
    """

    # Define folders
    local_export_folder = os.path.join(local_root, f'CAE_{experiment_name}_local')
    git_output_folder = os.path.join(repo_root, 'outputs', f'{experiment_name}_files')
    os.makedirs(git_output_folder, exist_ok=True)

    # Gather file names
    files_to_copy = []

    # Collect filenames to copy
    for idx in range(0, model_count):
      files_to_copy.append(f'{experiment_name}_model_{idx+1}.pth')
      files_to_copy.append(f'{experiment_name}_loss_{idx+1}.pth')
      files_to_copy.append(f'{experiment_name}_reconstruction_{idx+1}.pth')

    # Add description
    files_to_copy.append(f'{experiment_name}_description.txt')

    # Copy files into Git folder
    for file in files_to_copy:
        shutil.copy2(
            os.path.join(local_export_folder, file),
            os.path.join(git_output_folder, file)
        )

    # Git add, commit, push
    os.chdir(repo_root)
    os.system(f'git add outputs/{experiment_name}_files/*')
    os.system(f'git commit -m "Add {experiment_name}: models, losses, reconstructions, and description" || echo "Nothing to commit"')
    os.system('git push origin main')

    print(f"✅ Exported {experiment_name} files to: outputs/{experiment_name}_files/")

Overwriting ../CAE_setup_local/src/export.py


# Push SRC to GitHub

In [11]:
!git config --global user.email "vladislav.yushkevich.uve@gmail.com"
!git config --global user.name "vlad_uve"

In [12]:
!git clone https://vlad-uve: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@github.com/vlad-uve/CAE-MNIST.git

Cloning into 'CAE-MNIST'...
remote: Enumerating objects: 145, done.[K
remote: Counting objects: 100% (145/145), done.[K
remote: Compressing objects: 100% (116/116), done.[K
remote: Total 145 (delta 59), reused 80 (delta 24), pack-reused 0 (from 0)[K
Receiving objects: 100% (145/145), 4.09 MiB | 25.69 MiB/s, done.
Resolving deltas: 100% (59/59), done.


In [13]:
!rm -r ./CAE-MNIST/src
!cp -r ../CAE_setup_local/src ./CAE-MNIST/

In [14]:
%cd /content/CAE-MNIST
!git add -A
!git commit -m "Update modularized src from CAE_setup_local"
!git push origin main

/content/CAE-MNIST
[main 8fe9a60] Update modularized src from CAE_setup_local
 1 file changed, 82 insertions(+)
 create mode 100644 src/export.py
Enumerating objects: 6, done.
Counting objects: 100% (6/6), done.
Delta compression using up to 2 threads
Compressing objects: 100% (4/4), done.
Writing objects: 100% (4/4), 1.29 KiB | 1.29 MiB/s, done.
Total 4 (delta 2), reused 0 (delta 0), pack-reused 0
remote: Resolving deltas: 100% (2/2), completed with 2 local objects.[K
To https://github.com/vlad-uve/CAE-MNIST.git
   6ba40ab..8fe9a60  main -> main


# Push Git Ignore to GitHub

In [None]:
# write gitignore
%%writefile /content/CAE-MNIST/.gitignore

%%writefile /content/CAE-MNIST/.gitignore
# Ignore everything in outputs/ by default
outputs/**

# Allow experiment folders with valid files
!outputs/**/
!outputs/**/*.pth
!outputs/**/*.pt
!outputs/**/*.txt
!outputs/**/*.json
!outputs/**/*.csv
!outputs/**/*.png
!outputs/**/*.jpg


# Ignore all __pycache__ and checkpoint junk
__pycache__/
.ipynb_checkpoints/
*.pyc
*.pyo
*.pyd

Overwriting /content/CAE-MNIST/.gitignore


In [None]:
%cd /content/CAE-MNIST
!git add .gitignore
!git commit -m "Add global .gitignore for structured experiment outputs"
!git push origin main

/content/CAE-MNIST
[main 9f31b41] Add global .gitignore for structured experiment outputs
 1 file changed, 1 insertion(+)
Enumerating objects: 5, done.
Counting objects: 100% (5/5), done.
Delta compression using up to 2 threads
Compressing objects: 100% (3/3), done.
Writing objects: 100% (3/3), 341 bytes | 341.00 KiB/s, done.
Total 3 (delta 2), reused 0 (delta 0), pack-reused 0
remote: Resolving deltas: 100% (2/2), completed with 2 local objects.[K
To https://github.com/vlad-uve/CAE-MNIST.git
   4d9d373..9f31b41  main -> main
