<a href="https://colab.research.google.com/github/ayulockin/deepgenerativemodeling/blob/master/Hyperparameter_Sweep_Autoencoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup, Imports and Installation

In [0]:
import tensorflow as tf
print(tf.__version__)

from tensorflow import keras
import tensorflow.keras.backend as K

2.2.0-rc2


In [0]:
%%capture
!pip install wandb

import wandb
from wandb.keras import WandbCallback

!wandb login 69f60a7711ce6b8bbae91ac6d15e45d6b1f1430e

In [0]:
import os
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

from tqdm.notebook import tqdm

# Load the MNIST Dataset

In [0]:
from keras.datasets import mnist

# MNIST dataset
(train_img, train_label), (test_img, test_label) = mnist.load_data()

image_size = train_img.shape[1]
original_dim = image_size * image_size
train_img = np.reshape(train_img, [-1, original_dim])
test_img = np.reshape(test_img, [-1, original_dim])
train_img = train_img.astype('float32') / 255
test_img = test_img.astype('float32') / 255

## Just dataset info
print("X_train: ", train_img.shape)
print("X_test: ", test_img.shape)

Using TensorFlow backend.


Downloading data from https://s3.amazonaws.com/img-datasets/mnist.npz
X_train:  (60000, 784)
X_test:  (10000, 784)


# Model

In [0]:
class Autoencoder:
  """Fully connected autoencoder assembled from an encoder and a decoder Keras sub-model.

  Encoder: input -> Dense(intermediate_dim, relu) -> Dense(latent_space) -> LeakyReLU(0.3).
  Decoder: latent -> Dense(intermediate_dim, relu) -> Dense(original_dim, sigmoid).

  Args:
    input_shape: Shape of a single input sample, either an int (e.g. 784) or a
      tuple/list (e.g. (784,)); it is normalized to a tuple for keras.layers.Input.
    original_dim: Number of units in the reconstruction layer; should equal the
      flattened input size so the reconstruction can be compared with the input.
    intermediate_dim: Width of the hidden Dense layer in both encoder and decoder.
    latent_space: Size of the bottleneck (latent) representation.
  """

  def __init__(self, input_shape, original_dim, intermediate_dim, latent_space):
    self.input_shape = input_shape
    self.original_dim = original_dim
    self.intermediate_dim = intermediate_dim
    self.latent_space = latent_space

  @staticmethod
  def _as_shape(shape):
    """Return `shape` as a tuple: iterables are converted, scalars become 1-tuples.

    keras.layers.Input documents `shape` as a tuple; a bare int such as `(784)`
    (which is just 784, not a tuple) may not be accepted by every Keras version.
    """
    try:
      return tuple(shape)
    except TypeError:
      return (shape,)

  def build_model(self):
    """Build encoder and decoder, chain them, and return the end-to-end autoencoder.

    Side effect: sets `self.encoder` and `self.decoder` so each sub-model can be
    used on its own after the returned model has been trained.

    Returns:
      keras.models.Model mapping an input sample to its reconstruction.
    """
    encoder_input, encoder_output = self.build_encoder()
    self.build_decoder()
    # Run the encoder's latent tensor through the decoder sub-model.
    decoder_output = self.decoder(encoder_output)
    return keras.models.Model(inputs=[encoder_input], outputs=[decoder_output])

  def build_encoder(self):
    """Create the encoder sub-model and store it on `self.encoder`.

    Returns:
      (encoder_input, encoder_output) tensors, so build_model can wire the
      full autoencoder from the same input tensor.
    """
    encoder_input = keras.layers.Input(shape=self._as_shape(self.input_shape))
    x = keras.layers.Dense(self.intermediate_dim, activation='relu')(encoder_input)
    # Latent space: linear projection followed by a LeakyReLU with slope 0.3.
    encoder_output = keras.layers.Dense(self.latent_space)(x)
    encoder_output = keras.layers.LeakyReLU(0.3)(encoder_output)

    self.encoder = keras.models.Model(inputs=[encoder_input], outputs=[encoder_output])
    return encoder_input, encoder_output

  def build_decoder(self):
    """Create the decoder sub-model (latent -> reconstruction) and store it on `self.decoder`."""
    decoder_input = keras.layers.Input(shape=self._as_shape(self.latent_space))
    x = keras.layers.Dense(self.intermediate_dim, activation='relu')(decoder_input)
    # Sigmoid keeps reconstructed values in (0, 1).
    decoder_output = keras.layers.Dense(self.original_dim, activation='sigmoid')(x)

    self.decoder = keras.models.Model(inputs=[decoder_input], outputs=[decoder_output])

# Weights & Biases (W&B) Callbacks

In [0]:
class ReconstructionLogger(tf.keras.callbacks.Callback):
  """Keras callback that logs sample images and their reconstructions to W&B after each epoch.

  Args:
    batch_size: Number of images to log; the first `batch_size` rows of the
      sample set are used.
    sample_images: Optional array of flattened images to reconstruct. When None,
      the notebook-global `test_img` is used (looked up when the callback runs).
    image_shape: 2-D shape each flattened image/reconstruction is reshaped to
      for display; defaults to (28, 28) for MNIST.
  """

  def __init__(self, batch_size, sample_images=None, image_shape=(28, 28)):
    super(ReconstructionLogger, self).__init__()
    self.batch_size = batch_size
    self.sample_images = sample_images
    self.image_shape = image_shape

  def on_epoch_end(self, epoch, logs=None):
    # Keras invokes on_epoch_end(epoch, logs), so the parameters must be in this order.
    source = test_img if self.sample_images is None else self.sample_images
    sample_images = source[:self.batch_size]
    if len(sample_images) == 0:
      return

    # One batched forward pass instead of one predict() call per image; this also
    # removes the hard-coded count that broke whenever batch_size < 32.
    reconstructions = self.model.predict(sample_images, verbose=0)

    # Log both panels in a single call so they share one W&B step instead of two.
    wandb.log({
        "images": [wandb.Image(image.reshape(self.image_shape))
                   for image in sample_images],
        "reconstructions": [wandb.Image(reconstruction.reshape(self.image_shape))
                            for reconstruction in reconstructions],
    })

# Sweep Training Function

In [0]:
def train():
    """Run one sweep trial: build, compile and fit an autoencoder for the sweep's `latent_space`.

    Intended as the `function` passed to `wandb.agent`. Reads the notebook-global
    `train_img` and `test_img` arrays; the autoencoder's target is its own input.
    """
    # Default hyperparameters. When run under a sweep agent, values chosen by the
    # sweep (here `latent_space`) take precedence over these defaults.
    defaults = {
        'latent_space': 2,
        'batch_size': 32,
        'epochs': 10,
    }
    wandb.init(entity='ayush-thakur', project="keras-gan", config=defaults)
    config = wandb.config

    X_train = train_img
    X_test = test_img
    # Derive the flattened input size from the data instead of hard-coding 784.
    input_dim = X_train.shape[1]

    # Define the model; a real tuple is passed as the Input shape.
    ae = Autoencoder(input_shape=(input_dim,),
                     original_dim=input_dim,
                     intermediate_dim=512,
                     latent_space=config.latent_space)
    model = ae.build_model()

    model.compile(optimizer='adam',
                  loss='mean_squared_error')

    # Validation data gives WandbCallback the val_loss it monitors by default.
    model.fit(X_train,
              X_train,
              epochs=config.epochs,
              batch_size=config.batch_size,
              validation_data=(X_test, X_test),
              callbacks=[WandbCallback(),
                         ReconstructionLogger(config.batch_size)])

In [0]:
# Exhaustive grid search over the bottleneck size: the agent runs train() once per value.
sweep_config = dict(
    method='grid',
    parameters=dict(
        latent_space=dict(values=[2, 10, 100]),
    ),
)

In [0]:
# Register the sweep with the W&B server; the returned ID is what the agent below consumes.
sweep_id = wandb.sweep(
    sweep_config,
    entity='ayush-thakur',
    project="keras-gan",
)

Create sweep with ID: 17bqwnpv
Sweep URL: https://app.wandb.ai/ayush-thakur/keras-gan/sweeps/17bqwnpv


In [0]:
# Start an agent that fetches each configuration from the sweep and calls train() with it.
wandb.agent(
    sweep_id,
    function=train,
)

wandb: Agent Starting Run: dz273fbw with config:
	latent_space: 2
wandb: Agent Started Run: dz273fbw


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
wandb: Agent Finished Run: dz273fbw 

wandb: Agent Starting Run: 5eqplxt0 with config:
	latent_space: 10
wandb: Agent Started Run: 5eqplxt0


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
wandb: Agent Finished Run: 5eqplxt0 

wandb: Agent Starting Run: zzvj4eds with config:
	latent_space: 100
wandb: Agent Started Run: zzvj4eds


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
wandb: Agent Finished Run: zzvj4eds 

