# Variational AutoEncoder example
*add an introduction here*

## Environment set up
*add set up instructions here*

* `conda env update --file cvae_env.yaml --name <NAME>`
* `source ~/pw/.miniconda3c/etc/profile.d/conda.sh`
* `conda activate <NAME>`
* `python -m ipykernel install --user --name=<NAME> --display-name "Python (<NAME>)"`
* `pip install mlflow`

In [None]:
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import pandas as pd
import tensorflow as tf
import keras
from keras import ops
from keras import layers

# mlflow dependencies
import mlflow
from mlflow import MlflowClient
from pprint import pprint

In [None]:
# Directory that receives the trained weights and the per-run training histories.
model_dir = './model_d'

# exist_ok=True keeps this cell re-runnable when the directory is already there.
os.makedirs(name=model_dir, exist_ok=True)

## Create sampling layer

In [None]:
class Sampling(layers.Layer):
    """Draw a latent vector z from the pair (z_mean, z_log_var).

    The layer returns ``z_mean + exp(0.5 * z_log_var) * epsilon`` where
    ``epsilon`` is random normal noise, which keeps the sampling step
    differentiable with respect to both inputs.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Dedicated, fixed-seed generator so the noise stream is reproducible.
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, inputs):
        """Return one sample of z for every row of ``inputs = (z_mean, z_log_var)``."""
        mean, log_var = inputs
        mean_shape = ops.shape(mean)
        noise = keras.random.normal(
            shape=(mean_shape[0], mean_shape[1]),
            seed=self.seed_generator,
        )
        # exp(0.5 * log variance) is the standard deviation.
        std_dev = ops.exp(0.5 * log_var)
        return mean + std_dev * noise

## Build the encoder

In [None]:
# Size of the latent code produced by the encoder and consumed by the decoder.
latent_dim = 2

encoder_inputs = keras.Input(shape=(28, 28, 1))

# Two stride-2 convolutions with "same" padding, 32 then 64 filters.
x = encoder_inputs
for filters in (32, 64):
    x = layers.Conv2D(filters, 3, activation="relu", strides=2, padding="same")(x)

x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)

# Two parallel heads parameterize the latent distribution; Sampling draws z from them.
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])

# The encoder exposes all three tensors: the VAE loss needs the mean and the
# log-variance, the decoder needs the sampled z.
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()

## Build the decoder

In [None]:
latent_inputs = keras.Input(shape=(latent_dim,))

# Project the latent vector to a 7x7x64 feature map.
x = layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((7, 7, 64))(x)

# Two stride-2 transposed convolutions with "same" padding, 64 then 32 filters.
for filters in (64, 32):
    x = layers.Conv2DTranspose(filters, 3, activation="relu", strides=2, padding="same")(x)

# Single-channel output; the sigmoid keeps every value in (0, 1).
decoder_outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()

## Define the VAE as a `Model` with a custom `train_step`

In [None]:
class VAE(keras.Model):
    """Variational autoencoder that pairs an encoder with a decoder.

    Training uses a custom ``train_step`` that minimizes the sum of a
    reconstruction term (binary cross-entropy) and a KL-divergence term,
    and reports running means of the total and of each term.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        # Running means reported in the training logs.
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        """Return the three loss trackers owned by this model."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def _vae_losses(self, data):
        """Return ``(reconstruction_loss, kl_loss)`` for one batch as scalars."""
        mean, log_var, sampled = self.encoder(data)
        decoded = self.decoder(sampled)

        # Cross-entropy is summed over axes 1 and 2 of its result, then averaged
        # over the batch.
        per_pixel = keras.losses.binary_crossentropy(data, decoded)
        reconstruction_loss = ops.mean(ops.sum(per_pixel, axis=(1, 2)))

        # KL term per latent dimension, summed per sample and averaged over the batch.
        kl_per_dim = -0.5 * (1 + log_var - ops.square(mean) - ops.exp(log_var))
        kl_loss = ops.mean(ops.sum(kl_per_dim, axis=1))
        return reconstruction_loss, kl_loss

    def train_step(self, data):
        """Run one optimization step on ``data`` and return the running loss means."""
        with tf.GradientTape() as tape:
            reconstruction_loss, kl_loss = self._vae_losses(data)
            total_loss = reconstruction_loss + kl_loss

        weights = self.trainable_weights
        gradients = tape.gradient(total_loss, weights)
        self.optimizer.apply_gradients(zip(gradients, weights))

        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        # The total is reported under the key "loss".
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

## DVC

In [None]:
# Add the git repository used for DVC as a submodule named digits_dvc.
# Replace the URL below with the digits repository you created.
!git submodule add --force git@github.com:oobielodan/digits_dvc.git # change to the digits repo that you made

In [None]:
# DVC initialization and storage set up:
#   1. initialize DVC inside the digits_dvc submodule
#   2. register /demo-bucket as the default (-d) remote, named "dvcstorage"
!cd digits_dvc && dvc init --subdir 
!cd digits_dvc && dvc remote add -d dvcstorage /demo-bucket

In [None]:
# Initial commit to git: stage everything in the digits_dvc submodule
# (including the files created by `dvc init`) and commit it there.
!cd digits_dvc && git add .
!cd digits_dvc && git commit -m "loaded dependencies, mkdir -p, DVC init"

## MLFlow

Before running the cells below, start the MLflow tracking server from a terminal (not from this notebook): `mlflow server --host 127.0.0.1 --port 8080`

In [None]:
# Address of the locally running MLflow tracking server.
tracking_uri = "http://127.0.0.1:8080"

# The explicit client and the module-level fluent API must point at the same server.
client = MlflowClient(tracking_uri=tracking_uri)
mlflow.set_tracking_uri(tracking_uri)

In [None]:
# Quick reachability check: fetch the MLflow server's landing page.
!curl http://127.0.0.1:8080

In [None]:
%%javascript
// Sanity check: a pop-up appears if the notebook front end executes JavaScript.
alert("JavaScript is working!");

In [None]:
from IPython.display import IFrame

# Embed the MLflow UI in the notebook; it must stay the last expression of the
# cell so that it is rendered.
IFrame(src="http://127.0.0.1:8080", width=900, height=500)

In [None]:
# Fetch the experiments known to the tracking server (search_experiments is
# called without filters) and show the raw result.
all_experiments = client.search_experiments()
print(all_experiments)

In [None]:
# Summarize the "Default" experiment as a small dict.
# next(..., None) yields None instead of raising IndexError when no experiment
# named "Default" is present in the search results.
default_experiment = next(
    (
        {"name": experiment.name, "lifecycle_stage": experiment.lifecycle_stage}
        for experiment in all_experiments
        if experiment.name == "Default"
    ),
    None,
)

pprint(default_experiment)

In [None]:
# Provide an Experiment description that will appear in the UI.
# Adjacent string literals are concatenated as-is, so the first sentence
# carries the trailing space that separates it from the second.
experiment1_description = (
    "This is the digits forecasting project. "
    "This experiment contains the digit model for each of the numbers (0-9) trained separately."
)

# Provide searchable tags that define characteristics of the Runs that
# will be in this Experiment
experiment1_tags = {
    "project_name": "digit-forecasting",
    "model_type": "sequential",
    "team": "digit-ml",
    "project_quarter": "Q3-2024",
    "mlflow.note.content": experiment1_description,
}

# Create the Experiment under its unique name, or reuse the existing one so
# that re-running this cell does not fail because the name is already taken.
# Either way digit_experiment1 holds the experiment ID.
existing_experiment1 = client.get_experiment_by_name("Sequenced_Model")
if existing_experiment1 is None:
    digit_experiment1 = client.create_experiment(
        name="Sequenced_Model", tags=experiment1_tags
    )
else:
    digit_experiment1 = existing_experiment1.experiment_id

In [None]:
# Provide an Experiment description that will appear in the UI.
# Adjacent string literals are concatenated as-is, so the first sentence
# carries the trailing space that separates it from the second.
experiment2_description = (
    "This is the digits forecasting project. "
    "This experiment contains the digit model for numbers (0-9) trained all together."
)

# Provide searchable tags that define characteristics of the Runs that
# will be in this Experiment
experiment2_tags = {
    "project_name": "digit-forecasting",
    "model_type": "all digits",
    "team": "digit-ml",
    "project_quarter": "Q3-2024",
    "mlflow.note.content": experiment2_description,
}

# Create the Experiment under its unique name, or reuse the existing one so
# that re-running this cell does not fail because the name is already taken.
# Either way digit_experiment2 holds the experiment ID.
existing_experiment2 = client.get_experiment_by_name("Together_Model")
if existing_experiment2 is None:
    digit_experiment2 = client.create_experiment(
        name="Together_Model", tags=experiment2_tags
    )
else:
    digit_experiment2 = existing_experiment2.experiment_id

In [None]:
# Provide an Experiment description that will appear in the UI.
# Adjacent string literals are concatenated as-is, so the first sentence
# carries the trailing space that separates it from the second.
experiment3_description = (
    "This is the digits forecasting project. "
    "This experiment contains the digit model for randomized numbers (0-9) trained separately."
)

# Provide searchable tags that define characteristics of the Runs that
# will be in this Experiment
experiment3_tags = {
    "project_name": "digit-forecasting",
    "model_type": "randomized",  # spelled correctly so tag searches match
    "team": "digit-ml",
    "project_quarter": "Q3-2024",
    "mlflow.note.content": experiment3_description,
}

# Create the Experiment under its unique name, or reuse the existing one so
# that re-running this cell does not fail because the name is already taken.
# Either way digit_experiment3 holds the experiment ID.
existing_experiment3 = client.get_experiment_by_name("Randomize_Model")
if existing_experiment3 is None:
    digit_experiment3 = client.create_experiment(
        name="Randomize_Model", tags=experiment3_tags
    )
else:
    digit_experiment3 = existing_experiment3.experiment_id

In [None]:
# Look up the three digit experiments by name. mlflow.set_experiment returns the
# Experiment metadata (its experiment_id is used when training) and, as a side
# effect, makes that experiment the active one, so "Randomize_Model" is the
# active experiment once this cell has run.
digit_experiment1, digit_experiment2, digit_experiment3 = (
    mlflow.set_experiment(experiment_name)
    for experiment_name in ("Sequenced_Model", "Together_Model", "Randomize_Model")
)

## Train the VAE

*Make sure that `vae.weights.h5` does not already exist in the model directory if you want to start training from scratch; when the file is present, `train_model` loads it and continues from those weights.*

In [None]:
# Wrap the encoder and decoder built above into the trainable VAE.
vae = VAE(encoder, decoder)
# No loss is passed to compile(): VAE.train_step computes the loss itself.
vae.compile(optimizer=keras.optimizers.Adam())
# Y_train / Y_test are the digit labels; they are used below to select the
# images of one digit at a time.
(x_train, Y_train), (x_test, Y_test) = keras.datasets.mnist.load_data()

In [None]:
# Stop training early when the loss stops improving. The VAE is fit without
# validation data, so the callback's default monitor ("val_loss") would never be
# available and early stopping would never trigger; the training loss, which
# VAE.train_step reports under the key "loss", is monitored instead.
early_stopping_cb = keras.callbacks.EarlyStopping(
    monitor="loss", patience=5, restore_best_weights=True
)


def train_model(num, model, data, experiment, epochs=30, batch_size=128):
    """Train ``model`` on ``data``, persist the results and log the run to MLflow.

    If ``vae.weights.h5`` already exists in ``model_dir`` it is loaded first, so
    successive calls continue training from the previous call's weights.

    Args:
        num: Label of this training round; used in the history file name, the
            MLflow run name and the logged ``num`` parameter.
        model: Compiled Keras model to train.
        data: Training samples passed to ``model.fit``.
        experiment: ID of the MLflow experiment that receives the run.
        epochs: Maximum number of epochs (early stopping may end training sooner).
        batch_size: Batch size passed to ``model.fit``.
    """
    weights_path = os.path.join(model_dir, 'vae.weights.h5')

    # If the model has already been trained at least once, resume from those weights.
    if os.path.exists(weights_path):
        model.load_weights(weights_path)

    history = model.fit(
        data, epochs=epochs, batch_size=batch_size, callbacks=[early_stopping_cb]
    )
    model.save_weights(weights_path)  # save model weights after training

    # Keep the per-epoch training history next to the weights.
    history_path = os.path.join(model_dir, f'history_{num}.csv')
    pd.DataFrame(history.history).to_csv(history_path, index=False)

    # Initiate the MLflow run context for this iteration of training.
    with mlflow.start_run(run_name=f"{num}_test", experiment_id=experiment):
        # Log the parameters used for the model fit.
        mlflow.log_params({"num": num, "epochs": epochs, "batch_size": batch_size})

        # history.history maps each metric name to a list with one value per
        # epoch; log_metrics expects scalars, so log one step per epoch.
        for epoch in range(len(history.epoch)):
            mlflow.log_metrics(
                {name: float(values[epoch]) for name, values in history.history.items()},
                step=epoch,
            )

        # Log an instance of the trained model for later use.
        mlflow.keras.log_model(model, "model")

In [None]:
# Experiment 1: train the same VAE on one digit class at a time, 0 through 9.
for num in np.arange(10):
    # Positions of the samples labelled `num` in each split.
    train_filter = np.where(Y_train == num)
    test_filter = np.where(Y_test == num)
    x_trn, x_tst = x_train[train_filter], x_test[test_filter]

    # Merge both splits, append a channel axis and divide the pixel values by 255.
    digit_images = np.concatenate([x_trn, x_tst], axis=0)
    mnist_digits = np.expand_dims(digit_images, -1).astype("float32") / 255

    train_model(num, vae, mnist_digits, digit_experiment1.experiment_id)

In [None]:
# Add this model to DVC: copy the trained weights into the digits_dvc submodule
# under an experiment-specific name and pass them to dvcgit.sh (helper script
# not shown in this notebook; presumably it runs `dvc add` plus a git commit
# with the given message -- verify).
!cp ./model_d/vae.weights.h5 digits_dvc/experiment_1.weights.h5
!sh dvcgit.sh digits_dvc/experiment_1.weights.h5 "digit experiment 1"

# Remove both copies; without ./model_d/vae.weights.h5 the next train_model call
# has no weights file to load.
!rm digits_dvc/experiment_1.weights.h5
!rm ./model_d/vae.weights.h5

In [None]:
# Train all numbers at the same time.
# After the per-digit loop `mnist_digits` only holds the images of the last
# digit, so rebuild it from the complete train + test sets first.
# Assumes x_train / x_test are still the raw image arrays returned by
# load_data(), i.e. the notebook is run top to bottom.
# NOTE(review): `vae` is the same in-memory model trained in the previous
# experiment; deleting the weights file does not reset it -- confirm intended.
mnist_digits = np.expand_dims(np.concatenate([x_train, x_test], axis=0), -1).astype("float32") / 255
train_model("all", vae, mnist_digits, digit_experiment2.experiment_id)

In [None]:
# Add this model to DVC: copy the trained weights into the digits_dvc submodule
# under an experiment-specific name and pass them to dvcgit.sh (helper script
# not shown in this notebook; presumably it runs `dvc add` plus a git commit
# with the given message -- verify).
!cp ./model_d/vae.weights.h5 digits_dvc/experiment_2.weights.h5
!sh dvcgit.sh digits_dvc/experiment_2.weights.h5 "digit experiment 2"

# Remove both copies; without ./model_d/vae.weights.h5 the next train_model call
# has no weights file to load.
!rm digits_dvc/experiment_2.weights.h5
!rm ./model_d/vae.weights.h5

In [None]:
# Retrain the model n times: split the current dataset into n consecutive
# chunks and run one training round (one MLflow run) per chunk, numbered from 1.
n = 5

for count, arr in enumerate(np.array_split(mnist_digits, n), start=1):
    train_model(f"rand_{count}", vae, arr, digit_experiment3.experiment_id)

In [None]:
# Add this model to DVC: copy the trained weights into the digits_dvc submodule
# under an experiment-specific name and pass them to dvcgit.sh (helper script
# not shown in this notebook; presumably it runs `dvc add` plus a git commit
# with the given message -- verify).
!cp ./model_d/vae.weights.h5 digits_dvc/experiment_3.weights.h5
!sh dvcgit.sh digits_dvc/experiment_3.weights.h5 "digit experiment 3"

# Remove both copies so that no weights file is left behind in ./model_d.
!rm digits_dvc/experiment_3.weights.h5
!rm ./model_d/vae.weights.h5

*how to filter found here: https://stackoverflow.com/questions/51202181/how-do-i-select-only-a-specific-digit-from-the-mnist-dataset-provided-by-keras*

## Display a grid of reconstructed digits in the latent space

In [None]:
import matplotlib.pyplot as plt

def plot_latent_space(vae, n=30, figsize=15, scale=1.0):
    """Display an n*n 2D manifold of digits decoded from a grid of latent points.

    Args:
        vae: Model exposing a ``decoder`` whose ``predict`` maps 2-D latent
            points to 28x28 images.
        n: Number of grid points along each latent axis.
        figsize: Width and height of the matplotlib figure.
        scale: The grid spans ``[-scale, scale]`` on both latent axes.
    """
    digit_size = 28

    # Linearly spaced coordinates corresponding to the 2D plot of digit classes
    # in the latent space; y is reversed so the largest z[1] is the top row.
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    # Decode the whole grid with a single predict() call instead of one call
    # per grid point. Row-major order: entry i * n + j is (grid_x[j], grid_y[i]).
    xx, yy = np.meshgrid(grid_x, grid_y)
    z_grid = np.stack([xx.ravel(), yy.ravel()], axis=1)
    decoded = vae.decoder.predict(z_grid, verbose=0)

    # (n*n, 28, 28, 1) -> (row, col, height, width) -> one tiled image in which
    # tile (i, j) occupies rows i*28..(i+1)*28 and columns j*28..(j+1)*28.
    tiles = decoded.reshape(n, n, digit_size, digit_size)
    figure = tiles.transpose(0, 2, 1, 3).reshape(n * digit_size, n * digit_size)

    plt.figure(figsize=(figsize, figsize))
    # Put one tick at the center of every tile, labelled with its latent coordinate.
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()

plot_latent_space(vae)

## Display how the latent space clusters digits

In [None]:
def plot_label_clusters(vae, data, labels):
    """Scatter-plot the encoder's z_mean output for ``data``, colored by ``labels``."""
    # The encoder returns [z_mean, z_log_var, z]; only the means are plotted.
    latent_means = vae.encoder.predict(data, verbose=0)[0]
    first_axis, second_axis = latent_means[:, 0], latent_means[:, 1]

    plt.figure(figsize=(12, 10))
    plt.scatter(first_axis, second_axis, c=labels)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.show()

# Reload the training split, append a channel axis and divide the pixel values by 255.
(x_train, y_train), _ = keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train, -1)
x_train = x_train.astype("float32") / 255

plot_label_clusters(vae, x_train, y_train)