# Variational AutoEncoder Digits Example

Check the README for an introduction to the project and how to get started!

## Imports and Setup

In [None]:
import os
import numpy as np
import pandas as pd

os.environ["KERAS_BACKEND"] = "tensorflow"

# ml dependencies
import tensorflow as tf
import keras
from keras import ops
from keras import layers

# mlflow dependencies
import mlflow
from mlflow import MlflowClient
from pprint import pprint

In [None]:
# output directory for the saved weights and the per-run training histories
model_dir = './model-dir'
os.makedirs(model_dir, exist_ok=True)

# conda environment name that is handed to dvcgit.sh
env_name = "digits_env"  # <name of your env>

## Create the ML Model

### Create sampling layer

In [None]:
class Sampling(layers.Layer):
    """Draws the latent vector z from the pair (z_mean, z_log_var)."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # seeded generator that feeds the noise drawn in call()
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, inputs):
        mean, log_var = inputs
        # one noise value per (sample, latent dimension), matching z_mean's first two axes
        noise_shape = (ops.shape(mean)[0], ops.shape(mean)[1])
        noise = keras.random.normal(shape=noise_shape, seed=self.seed_generator)
        # exp(0.5 * log_var) turns the log-variance into a standard deviation
        std = ops.exp(0.5 * log_var)
        return mean + std * noise

### Build the encoder

In [None]:
# number of latent dimensions; also sets the decoder's input size
latent_dim = 2

# 28x28 single-channel images in, (z_mean, z_log_var, z) out
encoder_inputs = keras.Input(shape=(28, 28, 1))
# two stride-2 convolutions with "same" padding halve the spatial size twice (28 -> 14 -> 7)
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
# two parallel linear heads produce the mean and the log-variance of the latent distribution
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
# sample z from the two heads with the custom Sampling layer
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()

### Build the decoder

In [None]:
# maps a latent vector of size latent_dim back to a single-channel image
latent_inputs = keras.Input(shape=(latent_dim,))
# project to a 7x7x64 feature map, then upsample with two stride-2 transposed convolutions
x = layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((7, 7, 64))(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
# sigmoid keeps every output pixel in [0, 1]
decoder_outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()

### Define the VAE as a `Model` with a custom `train_step`

In [None]:
class VAE(keras.Model):
    """Variational autoencoder that chains an encoder and a decoder.

    Training is driven by a custom ``train_step`` that minimises the sum of a
    binary cross-entropy reconstruction loss and a KL-divergence term, and
    reports running means of all three quantities.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        # running means of the losses, exposed through `metrics`
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        """Return the three loss trackers maintained by ``train_step``."""
        trackers = [self.total_loss_tracker]
        trackers.append(self.reconstruction_loss_tracker)
        trackers.append(self.kl_loss_tracker)
        return trackers

    def train_step(self, data):
        """Run one optimisation step on a batch and return the running mean losses."""
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            decoded = self.decoder(z)

            # cross-entropy per pixel, summed over axes 1 and 2, averaged over the batch
            pixel_bce = keras.losses.binary_crossentropy(data, decoded)
            reconstruction_loss = ops.mean(ops.sum(pixel_bce, axis=(1, 2)))

            # KL term per latent dimension, summed over axis 1, averaged over the batch
            kl_per_dim = -0.5 * (
                1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var)
            )
            kl_loss = ops.mean(ops.sum(kl_per_dim, axis=1))

            total_loss = reconstruction_loss + kl_loss

        weights = self.trainable_weights
        grads = tape.gradient(total_loss, weights)
        self.optimizer.apply_gradients(zip(grads, weights))

        updates = (
            (self.total_loss_tracker, total_loss),
            (self.reconstruction_loss_tracker, reconstruction_loss),
            (self.kl_loss_tracker, kl_loss),
        )
        for tracker, value in updates:
            tracker.update_state(value)

        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

## DVC

[Data Version Control](https://dvc.org/doc/user-guide) is an open-source tool for handling machine learning projects. It helps with data management, ML pipeline automation, and experiment management, making projects reproducible and improving collaboration. DVC integrates with Git to track changes in your data and models, enabling you to version control your machine learning workflow in a way similar to how you manage code.

-------------------------------------------------------------------------------------------------------------------------------------

In the lines below, the `&&` symbol is used multiple times. This symbol is originally a logical operator (the command on the right will only run if the command on the left executes successfully). However, when using `!` in a Jupyter notebook, the Linux commands are executed within the directory where the notebook is currently located. This behavior prevents commands from being run in separate directories if `!cd` is on its own line. Using `&&` ensures that the `dvc` commands are executed within the `dvc` submodule directory, without affecting the repository where the notebook resides.

In [None]:
dvc_repo_link = "git@github.com:oobielodan/digits_dvc.git" # <ssh link to the repo you set aside for dvc>
dvc_storage = "/demo-bucket" # <complete path to the mounted storage you have set up for dvc>

In [None]:
# grab your dvc repository -> the --force flag allows for this to still run if the submodule had already been created at a prior time
!git submodule add --force "{dvc_repo_link}"

In [None]:
dvc_repo = "digits_dvc" # <name of the repository/submodule you just added for dvc> -> should appear as a folder in the current directory

In [None]:
# DVC initialization and storage set up
!cd "{dvc_repo}" && dvc init
!cd "{dvc_repo}" && dvc remote add -d dvcstorage "{dvc_storage}"

In [None]:
# initial commit to git
!cd "{dvc_repo}" && git add .
!cd "{dvc_repo}" && git commit -m "loaded dependencies, mkdir -p, DVC init"

## MLFlow
MLflow is designed to help simplify the ML workflow, assisting users throughout the various stages of development and deployment. In this notebook, we use its autologging capabilities to log information on the model, its parameters and results, etc. Documentation and more information can be found at [the MLFlow website](https://mlflow.org/docs/latest/index.html).

-------------------------------------------------------------------------------------------------------------------------------------

To get started with MLflow, run `mlflow server --host 127.0.0.1 --port 8080` in the command line. The `mlflow server` command needs to run in the background and therefore cannot be executed directly in a Jupyter notebook, as each cell must complete execution before the next one can run.

### Configuration
*If you used a different host and/or port during initialization, make sure to update the following URIs accordingly.*

In [None]:
# point both the client object and the module-level mlflow API at the initialized server
tracking_uri = "http://127.0.0.1:8080"
client = MlflowClient(tracking_uri=tracking_uri)
mlflow.set_tracking_uri(tracking_uri)

In [None]:
# view the metadata associated with all the experiments that are currently on the server 
all_experiments = client.search_experiments()
print(all_experiments)

In [None]:
# example for accessing elements from the returned collection of experiments;
# next(..., None) yields None instead of raising IndexError when no experiment
# named "Default" is present in the search results
default_experiment = next(
    (
        {"name": experiment.name, "lifecycle_stage": experiment.lifecycle_stage}
        for experiment in all_experiments
        if experiment.name == "Default"
    ),
    None,
)

pprint(default_experiment)

In [None]:
# working on getting the server to display in the notebook --------------------------------------
# !curl http://127.0.0.1:8080
# %%javascript
# alert("JavaScript is working!");
# from IPython.display import IFrame
# IFrame("http://127.0.0.1:8080", 900,500)

### Experiment 1
In Experiment 1, we train the Digit CVAE model on multiple datasets. To create these datasets, we split the original dataset into five equal, randomized parts. After each training session, we save the weights and use them as the starting point for retraining the model on the next dataset.

In [None]:
# provide an experiment description that will appear in the UI
# (the trailing space keeps the two sentences from running together once concatenated)
experiment1_description = (
    "This is the digits forecasting project. "
    "This experiment contains the digit model for randomized numbers (0-9) trained separately."
)

# provide searchable tags for the experiment
experiment1_tags = {
    "project_name": "digit-forecasting",
    "model_type": "randomized",
    "team": "digit-ml",
    "project_quarter": "Q3-2024",
    "mlflow.note.content": experiment1_description,
}

# create the experiment under a unique name; when the cell is re-run the name is
# already taken, so reuse the existing experiment's id instead of failing
experiment1_name = "Randomize_Model"
existing_experiment1 = client.get_experiment_by_name(experiment1_name)
if existing_experiment1 is None:
    digit_experiment1 = client.create_experiment(
        name=experiment1_name, tags=experiment1_tags
    )
else:
    digit_experiment1 = existing_experiment1.experiment_id

### Experiment 2
In Experiment 2, we train the Digit CVAE model on all digit samples simultaneously, without any subsequent retraining using the weights.

In [None]:
# provide an experiment description that will appear in the UI
# (the trailing space keeps the two sentences from running together once concatenated)
experiment2_description = (
    "This is the digits forecasting project. "
    "This experiment contains the digit model for numbers (0-9) trained all together."
)

# provide searchable tags for the experiment
experiment2_tags = {
    "project_name": "digit-forecasting",
    "model_type": "all digits",
    "team": "digit-ml",
    "project_quarter": "Q3-2024",
    "mlflow.note.content": experiment2_description,
}

# create the experiment under a unique name; when the cell is re-run the name is
# already taken, so reuse the existing experiment's id instead of failing
experiment2_name = "Together_Model"
existing_experiment2 = client.get_experiment_by_name(experiment2_name)
if existing_experiment2 is None:
    digit_experiment2 = client.create_experiment(
        name=experiment2_name, tags=experiment2_tags
    )
else:
    digit_experiment2 = existing_experiment2.experiment_id

### Experiment 3
In Experiment 3, we revisit the approach used in Experiment 1 - initializing the model with the weights from a previous training session and retraining it from there. However, in this experiment, we train the Digit CVAE model sequentially on each of the 10 digits (0–9), one digit at a time. After each training session, we save the weights and use them to retrain the model on the next digit. This approach induces a 'forgetting' effect, where the model gradually loses its ability to recognize previous digits with each subsequent training session.

In [None]:
# provide an experiment description that will appear in the UI
# (the trailing space keeps the two sentences from running together once concatenated)
experiment3_description = (
    "This is the digits forecasting project. "
    "This experiment contains the digit model for each of the numbers (0-9) trained separately."
)

# provide searchable tags that define characteristics of the runs that will be in this experiment
experiment3_tags = {
    "project_name": "digit-forecasting",
    "model_type": "sequential",
    "team": "digit-ml",
    "project_quarter": "Q3-2024",
    "mlflow.note.content": experiment3_description,
}

# create the experiment under a unique name; when the cell is re-run the name is
# already taken, so reuse the existing experiment's id instead of failing
experiment3_name = "Sequenced_Model"
existing_experiment3 = client.get_experiment_by_name(experiment3_name)
if existing_experiment3 is None:
    digit_experiment3 = client.create_experiment(
        name=experiment3_name, tags=experiment3_tags
    )
else:
    digit_experiment3 = existing_experiment3.experiment_id

### Experiment Set Up

In [None]:
# look up each experiment by name, in order, and keep the returned metadata
# (its `experiment_id` is what the training cells pass to train_model)
digit_experiment1, digit_experiment2, digit_experiment3 = [
    mlflow.set_experiment(experiment_name)
    for experiment_name in ("Randomize_Model", "Together_Model", "Sequenced_Model")
]

## Train the VAE

*Make sure that 'vae.weights.h5' does not already exist in the model directory if you want to train from the beginning.*

In [None]:
vae = VAE(encoder, decoder)
vae.compile(optimizer=keras.optimizers.Adam())
(x_train, Y_train), (x_test, Y_test) = keras.datasets.mnist.load_data()

In [None]:
# stops training early if the training loss does not improve for 5 epochs.
# fit() below is called without validation data, so the callback's default
# "val_loss" monitor would never be available; "loss" is the key that
# VAE.train_step actually reports
early_stopping_cb = keras.callbacks.EarlyStopping(
    monitor="loss", patience=5, restore_best_weights=True
)


def train_model(num, model, data, experiment):
    """Train `model` on `data` inside an MLflow run and persist the results.

    If 'vae.weights.h5' already exists in `model_dir`, those weights are loaded
    first, so consecutive calls continue training from the previous call.
    After training, the weights are written back to the same file and the
    training history is saved as 'history_<num>.csv' in `model_dir`.

    Args:
        num: Label for this training round; used in the MLflow run name
            ("<num>_test") and in the history file name.
        model: Compiled model to train.
        data: Training data passed straight to `model.fit`.
        experiment: MLflow experiment id the run is logged under.
    """
    weights_path = os.path.join(model_dir, 'vae.weights.h5')

    # if the model has already been trained at least once, resume from those weights
    if os.path.exists(weights_path):
        model.load_weights(weights_path)

    mlflow.autolog()

    run_name = f"{num}_test"  # define a run name for this iteration of training

    # initiate the MLflow run context
    with mlflow.start_run(run_name=run_name, experiment_id=experiment):
        history = model.fit(data, epochs=30, batch_size=128, callbacks=[early_stopping_cb])
        model.save_weights(weights_path)  # save model weights after training

        hist_pd = pd.DataFrame(history.history)
        hist_pd.to_csv(os.path.join(model_dir, f'history_{num}.csv'), index=False)

### Experiment 1
*How to filter mnist data found here: https://stackoverflow.com/questions/51202181/how-do-i-select-only-a-specific-digit-from-the-mnist-dataset-provided-by-keras*

In [None]:
# retraining the model n times, each time on the next slice of the full dataset
n = 5

# all train and test images, with a trailing channel axis, scaled to [0, 1]
mnist_digits = np.expand_dims(np.concatenate([x_train, x_test], axis=0), -1).astype("float32") / 255

# these runs belong to Experiment 1 ("Randomize_Model"), so log them under
# digit_experiment1 rather than digit_experiment3
for count, arr in enumerate(np.array_split(mnist_digits, n), start=1):
    train_model(f"rand_{count}", vae, arr, digit_experiment1.experiment_id)

In [None]:
# add model 1 to dvc 
!cp ./"{model_dir}"/vae.weights.h5 "{dvc_repo}"/experiment_1.weights.h5
!sh dvcgit.sh experiment_1.weights.h5 "digit experiment 1" "{dvc_repo}" "{env_name}"

!rm "{dvc_storage}"/experiment_1.weights.h5
!rm ./"{model_dir}"/vae.weights.h5

### Experiment 2

In [None]:
# train all numbers at the same time
# NOTE(review): `vae` is the same model instance that was trained in Experiment 1, and
# train_model only reloads weights when 'vae.weights.h5' exists, so deleting that file
# does not reset the weights already held in memory -- confirm whether the model should
# be rebuilt before this run
train_model("all", vae, mnist_digits, digit_experiment2.experiment_id)

In [None]:
# add model 2 to dvc 
!cp ./"{model_dir}"/vae.weights.h5 "{dvc_repo}"/experiment_2.weights.h5
!sh dvcgit.sh experiment_2.weights.h5 "digit experiment 2" "{dvc_repo}" "{env_name}"

!rm "{dvc_repo}"/experiment_2.weights.h5
!rm ./"{model_dir}"/vae.weights.h5

### Experiment 3

In [None]:
# training one number at a time
for num in np.arange(10):
    train_filter = np.where(Y_train == num)
    test_filter = np.where(Y_test == num)

    x_trn = x_train[train_filter]
    x_tst = x_test[test_filter]

    # use a loop-local name so the full `mnist_digits` array used by the other
    # experiments is not overwritten with a single digit's images
    digit_images = np.expand_dims(np.concatenate([x_trn, x_tst], axis=0), -1).astype("float32") / 255

    # these runs belong to Experiment 3 ("Sequenced_Model"), so log them under
    # digit_experiment3 rather than digit_experiment1
    train_model(num, vae, digit_images, digit_experiment3.experiment_id)

In [None]:
# add model 3 to dvc 
!cp ./"{model_dir}"/vae.weights.h5 "{dvc_repo}"/experiment_3.weights.h5
!sh dvcgit.sh experiment_3.weights.h5 "digit experiment 3" "{dvc_repo}" "{env_name}"

!rm "{dvc_repo}"/experiment_3.weights.h5
!rm ./"{model_dir}"/vae.weights.h5

------------------------------------------------------------------------------------------------
*`dvcgit.sh` is a script used for dvc and git tracking. The correct call is as follows (all arguments are required):* `sh dvcgit.sh <file_name> <commit_message> <dvc_repo_name> <conda_env_name>`

## Display a grid of reconstructed digits in the latent space

In [None]:
import matplotlib.pyplot as plt

def plot_latent_space(vae, n=30, figsize=15):
    """Display an n*n 2D manifold of digits decoded from a grid of latent points.

    Args:
        vae: Model exposing a `decoder` whose `predict` maps 2-D latent points
            to 28x28 images.
        n: Number of grid points along each latent axis.
        figsize: Width and height of the (square) matplotlib figure.
    """
    digit_size = 28
    scale = 1.0
    # linearly spaced coordinates corresponding to the 2D plot
    # of digit classes in the latent space; y is reversed so the largest
    # z[1] value ends up in the top row of the image
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    # decode the whole grid in a single predict() call instead of n*n calls;
    # row-major flattening makes sample i * n + j the point (grid_x[j], grid_y[i])
    mesh_x, mesh_y = np.meshgrid(grid_x, grid_y)
    z_samples = np.stack([mesh_x.ravel(), mesh_y.ravel()], axis=1)
    decoded = vae.decoder.predict(z_samples, verbose=0)

    # (n*n, 28, 28, 1) -> (row, col, y, x) tiles -> one (n*28, n*28) image in
    # which tile (i, j) occupies rows i*28..(i+1)*28 and columns j*28..(j+1)*28
    tiles = decoded.reshape(n, n, digit_size, digit_size)
    figure = tiles.transpose(0, 2, 1, 3).reshape(n * digit_size, n * digit_size)

    plt.figure(figsize=(figsize, figsize))
    # place one tick at the centre of every tile, labelled with its latent coordinate
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()

plot_latent_space(vae)

## Display how the latent space clusters digits

In [None]:
def plot_label_clusters(vae, data, labels):
    """Scatter-plot the encoder's latent means for `data`, coloured by `labels`."""
    # the encoder returns three outputs; only the first (the means) is plotted
    latent_means, _log_var, _sampled = vae.encoder.predict(data, verbose=0)
    first_dim = latent_means[:, 0]
    second_dim = latent_means[:, 1]

    plt.figure(figsize=(12, 10))
    plt.scatter(first_dim, second_dim, c=labels)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.show()

(x_train, y_train), _ = keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train, -1).astype("float32") / 255

plot_label_clusters(vae, x_train, y_train)