In [1]:
import tensorflow as tf
import numpy as np
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from IPython import display

import zipfile
import urllib.request
import os
import random

In [2]:
# Training schedule
EPOCHS = 100       # number of passes over the training dataset
BATCH_SIZE = 256   # number of images per batch

# Data / model dimensions
IMG_SIZE = 64      # images are resized to IMG_SIZE x IMG_SIZE
LATENT_DIM = 256   # length of the latent vector

In [3]:
# make the data directory; an already existing directory is fine, but any
# other OS error (e.g. missing permissions) is now reported instead of being
# silently swallowed
os.makedirs('/tmp/anime', exist_ok=True)

# download the zipped dataset; the archive itself is written to the current
# working directory under `data_file_name`
data_url = "https://storage.googleapis.com/learning-datasets/Resources/anime-faces.zip"
data_file_name = "animefaces.zip"
urllib.request.urlretrieve(data_url, data_file_name)

('animefaces.zip', <http.client.HTTPMessage at 0x7f9c5a0ef250>)

In [4]:
# Extract into /tmp/anime (the data directory), not into a relative "anime"
# folder: the image paths are later listed from /tmp/anime/images/.
# The context manager closes the archive even if extraction raises.
with zipfile.ZipFile(data_file_name, 'r') as zip_ref:
  zip_ref.extractall("/tmp/anime")

In [3]:
def get_dataset_slice_paths(image_dir):
  '''Lists `image_dir` and returns every entry joined onto `image_dir`.

  Args:
    image_dir -- directory whose entries are listed

  Returns:
    list of paths, one per directory entry, in the order given by os.listdir
  '''
  return [os.path.join(image_dir, entry) for entry in os.listdir(image_dir)]

def map_image(image_filename):
  '''Reads one JPEG file and converts it into a normalized image tensor.

  Args:
    image_filename -- path of the JPEG file to load

  Returns:
    float32 tensor of shape (IMG_SIZE, IMG_SIZE, 3), pixel values divided by 255
  '''
  image_raw = tf.io.read_file(image_filename)
  # Always decode to 3 channels: without `channels=3` a grayscale (or CMYK)
  # JPEG decodes to a different channel count and the reshape below fails.
  image = tf.image.decode_jpeg(image_raw, channels=3)
  image = tf.cast(image, tf.float32)
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  image = image/255.0
  # pins the static shape so downstream layers see fully defined dimensions
  image = tf.reshape(image, (IMG_SIZE, IMG_SIZE, 3))
  return image

In [4]:
paths = get_dataset_slice_paths("/tmp/anime/images/")
# fail early with a readable message instead of an obscure tf.data error
assert paths, "no image files found in /tmp/anime/images/"
random.shuffle(paths)

path_len = len(paths)

# 80/20 train/validation split. Slice exactly at train_len so that train_len
# is the real number of training paths (previously one extra path was taken).
train_len = int(0.8*path_len)
train_path = paths[:train_len]
validation_path = paths[train_len:]

# training pipeline: decode -> shuffle (4096-image buffer) -> batch
train_dataset = tf.data.Dataset.from_tensor_slices(train_path)
train_dataset = train_dataset.map(map_image)
train_dataset = train_dataset.shuffle(4096).batch(BATCH_SIZE)

# validation pipeline: decode -> batch (no shuffling)
validation_dataset = tf.data.Dataset.from_tensor_slices(validation_path)
validation_dataset = validation_dataset.map(map_image)
validation_dataset = validation_dataset.batch(BATCH_SIZE)

In [5]:
# report how many batches each split yields
print(
    f'No. of batches in training set: {len(train_dataset)}',
    f'No. of batches in validation set: {len(validation_dataset)}',
    sep='\n',
)

No. of batches in training set: 199
No. of batches in validation set: 50


Build Model

In [6]:
class Reparametrize(tf.keras.layers.Layer):
  '''Samples z = mu + exp(0.5 * sigma) * eps, with eps drawn from a standard normal.

  exp(0.5 * sigma) is used as the scale of the noise, i.e. `sigma` is treated
  as a log-variance rather than a standard deviation.
  '''

  def call(self, inputs):
    '''Draws one sample per row.

    Args:
      inputs -- pair (mu, sigma); the noise takes the first two dimensions of sigma

    Returns:
      the sampled tensor z
    '''
    mu, sigma = inputs
    sigma_shape = tf.shape(sigma)
    noise_shape = (sigma_shape[0], sigma_shape[1])
    noise = tf.keras.backend.random_normal(shape = noise_shape)

    return mu + tf.exp(0.5*sigma) * noise

In [7]:
def encoder_layers(inputs, latent_dim):
  '''Builds the encoder stack on top of `inputs`.

  Two strided convolutions (each followed by batch normalization) halve the
  spatial size twice; the result is flattened, passed through a small dense
  layer and projected onto the two latent heads.

  Args:
    inputs -- input image tensor
    latent_dim -- number of units of each latent head

  Returns:
    mu -- output of the mean head
    sigma -- output of the second head (consumed as a log-variance downstream)
    shape of the last convolutional feature map (used to size the decoder)
  '''
  conv1 = tf.keras.layers.Conv2D(filters = 64, kernel_size = 3, strides = 2,
                                 activation = "relu", padding = "same")(inputs)
  bn1 = tf.keras.layers.BatchNormalization()(conv1)

  conv2 = tf.keras.layers.Conv2D(filters = 128, kernel_size = 3, strides = 2,
                                 activation = "relu", padding = "same")(bn1)
  bn2 = tf.keras.layers.BatchNormalization()(conv2)

  flat = tf.keras.layers.Flatten()(bn2)
  # NOTE(review): this 20-unit layer is narrower than the latent heads it
  # feeds -- confirm the bottleneck is intended.
  d1 = tf.keras.layers.Dense(20, activation = "relu")(flat)
  bn3 = tf.keras.layers.BatchNormalization()(d1)

  # The latent heads must be linear: a ReLU here would force both the mean
  # and the log-variance to be non-negative, so the encoder could never
  # produce negative means or variances below 1.
  mu = tf.keras.layers.Dense(latent_dim)(bn3)
  sigma = tf.keras.layers.Dense(latent_dim)(bn3)

  return mu, sigma, bn2.shape

In [8]:
def encoder_model(input_shape, latent_dim):
  '''Wraps the encoder layers and the sampling layer into a Keras model.

  Args:
    input_shape -- shape of one input image (without the batch dimension)
    latent_dim -- size of the latent heads

  Returns:
    model -- Keras model mapping an image to [mu, sigma, z]
    conv_shape -- shape of the encoder's last convolutional feature map
  '''
  image_input = tf.keras.layers.Input(shape = input_shape)
  mu, sigma, conv_shape = encoder_layers(image_input, latent_dim)
  encoder_outputs = [mu, sigma, Reparametrize()((mu, sigma))]

  return tf.keras.Model(inputs = image_input, outputs = encoder_outputs), conv_shape

In [9]:
def decoder_layers(conv_shape, inputs):
  '''Builds the decoder stack on top of the latent tensor `inputs`.

  Args:
    conv_shape -- shape of the encoder's last feature map; entries 1..3 give
                  the height, width and channels the latent vector is reshaped to
    inputs -- latent tensor

  Returns:
    3-channel output tensor with sigmoid activation
  '''
  height, width, channels = conv_shape[1], conv_shape[2], conv_shape[3]

  # project the latent vector and fold it back into a feature map
  features = tf.keras.layers.Dense(height*width*channels, activation = "relu")(inputs)
  features = tf.keras.layers.BatchNormalization()(features)
  features = tf.keras.layers.Reshape((height, width, channels))(features)

  # two stride-2 transposed convolutions, each doubling the spatial size
  for n_filters in (128, 64):
    features = tf.keras.layers.Conv2DTranspose(filters = n_filters, kernel_size = 3,
                                               strides = 2, activation = "relu",
                                               padding = "same")(features)
    features = tf.keras.layers.BatchNormalization()(features)

  # final stride-1 layer maps to 3 channels squashed by a sigmoid
  return tf.keras.layers.Conv2DTranspose(filters = 3, kernel_size = 3, strides = 1,
                                         activation = "sigmoid", padding = "same")(features)

In [10]:
def decoder_model(latent_dim, conv_shape):
  '''Wraps the decoder layers into a Keras model.

  Args:
    latent_dim -- length of the latent vector fed to the decoder
    conv_shape -- shape of the encoder's last convolutional feature map

  Returns:
    Keras model mapping a latent vector to the decoder output
  '''
  latent_input = tf.keras.layers.Input(shape = (latent_dim,))
  decoded = decoder_layers(conv_shape, latent_input)

  return tf.keras.Model(inputs = latent_input, outputs = decoded)

In [71]:
def kl_divergence_loss(mu, sigma):
  '''Computes -0.5 * mean(1 + sigma - mu^2 - exp(sigma)) along axis 1.

  Args:
    mu -- latent means
    sigma -- second encoder head; exp(sigma) enters the formula as the variance

  Returns:
    tensor with axis 1 reduced away (one value per row of the inputs)
  '''
  # NOTE(review): the reduction is a mean over axis 1, not a sum; compared
  # with a summed KL term this scales the loss by 1/size-of-axis-1 -- confirm
  # that this weighting is intended.
  per_dim_terms = 1 + sigma - tf.square(mu) - tf.math.exp(sigma)

  return -0.5 * tf.reduce_mean(per_dim_terms, axis = 1)

In [20]:
def vae_model(encoder, decoder, input_shape, latent_dim):
  '''Chains encoder and decoder into the full VAE and attaches the KL loss.

  Args:
    encoder -- encoder model returning (mu, sigma, z)
    decoder -- decoder model mapping z to a reconstruction
    input_shape -- shape of one input image (without the batch dimension)
    latent_dim -- unused; kept so existing callers keep working

  Returns:
    Keras model mapping an image to its reconstruction, with the KL
    divergence registered through add_loss
  '''
  inputs = tf.keras.layers.Input(input_shape)

  # Call the encoder with the inputs only. The second positional argument of
  # a Keras model call is `training`, so passing latent_dim there silently
  # forced the encoder into training mode on every call.
  mu, sigma, z = encoder(inputs)
  outputs = decoder(z)
  kl_loss = kl_divergence_loss(mu, sigma)

  model = tf.keras.Model(inputs = inputs, outputs = outputs)
  model.add_loss(kl_loss)

  return model

In [21]:
def get_models(input_shape, latent_dim):
  '''Creates the encoder, the decoder and the VAE that combines them.

  Args:
    input_shape -- shape of one input image (without the batch dimension)
    latent_dim -- length of the latent vector

  Returns:
    tuple (encoder, decoder, vae)
  '''
  encoder, conv_shape = encoder_model(input_shape = input_shape, latent_dim = latent_dim)
  decoder = decoder_model(latent_dim = latent_dim, conv_shape = conv_shape)

  return encoder, decoder, vae_model(encoder, decoder, input_shape, latent_dim)

In [72]:
# Derive the model input shape from IMG_SIZE (the size map_image resizes to)
# instead of repeating the literal 64, so the two cannot drift apart.
encoder, decoder, vae = get_models(input_shape = (IMG_SIZE, IMG_SIZE, 3), latent_dim = LATENT_DIM)

In [73]:
# reconstruction loss between an input batch and its reconstruction
mse_loss = tf.keras.losses.MeanSquaredError()
# running mean of the training loss
loss_metric = tf.keras.metrics.Mean()
# optimizer with default settings
optimizer = tf.keras.optimizers.Adam()

In [74]:
@tf.function
def train_step(model, x, optimizer, mse_loss):
  """Runs one optimization step on a batch and returns the scalar loss.

  Args:
    model -- the VAE; losses registered on it via add_loss are included
    x -- batch of input images, also used as the reconstruction target
    optimizer -- optimizer applied to model.trainable_weights
    mse_loss -- loss object returning the mean squared error over all elements

  Returns:
    scalar tensor: reconstruction loss plus the batch mean of every loss
    registered on `model`
  """
  with tf.GradientTape() as tape:
    # training=True so that BatchNormalization layers use batch statistics
    # and update their moving averages while training
    pred = model(x, training=True)

    # mse_loss averages the squared error over batch*height*width*channels.
    # Multiplying by the number of values per image (height*width*channels)
    # turns it into the summed squared error per image, averaged over the
    # batch. The factor is derived from x instead of hard-coding 12288.
    values_per_image = tf.cast(tf.reduce_prod(tf.shape(x)[1:]), pred.dtype)
    loss = mse_loss(x, pred) * values_per_image

    # Use the losses of the model that was passed in (not a global model) and
    # reduce each one to a scalar, so the total is a scalar instead of a
    # per-sample vector that the tape would implicitly sum.
    for extra_loss in model.losses:
      loss += tf.reduce_mean(extra_loss)

  gradients = tape.gradient(loss, model.trainable_weights)
  optimizer.apply_gradients(zip(gradients, model.trainable_weights))

  return loss

In [75]:
def generate_and_save_images(model, epoch, step, test_input):
  """Decodes `test_input` and shows the results in a square grid of images.

  Despite the name, nothing is written to disk; the figure is only displayed.

  Args:

  model -- the decoder model
  epoch -- current epoch number during training
  step -- current step number during training
  test_input -- random tensor with shape (16, LATENT_DIM); other batch sizes
                are laid out on the smallest square grid that fits them
  """
  predictions = model.predict(test_input)

  # smallest square grid holding every image (4x4 for the usual 16 images);
  # a fixed 4x4 grid would raise as soon as more than 16 images are passed
  num_images = predictions.shape[0]
  grid_size = max(1, int(np.ceil(np.sqrt(num_images))))

  fig = plt.figure(figsize=(4,4))

  for i in range(num_images):
    plt.subplot(grid_size, grid_size, i+1)
    # scale the model output by 255 and convert to integers for display
    img = predictions[i, :, :, :] * 255
    img = img.astype('int32')
    plt.imshow(img)
    plt.axis('off')

  fig.suptitle("epoch: {}, step: {}".format(epoch, step))
  plt.show()
  # release the figure explicitly: this helper is called repeatedly during
  # training and open figures would otherwise accumulate
  plt.close(fig)

In [76]:
# 16 latent vectors drawn once and reused for every preview, so the generated
# images can be compared from one preview to the next
random_vector_for_generation = tf.random.normal(shape=(16, LATENT_DIM))

In [79]:
for epoch in range(EPOCHS):
  for step, x_batch_train in enumerate(train_dataset):
    loss = train_step(vae, x_batch_train, optimizer, mse_loss)

    # fold this batch into the running mean (the metric is never reset, so it
    # averages over every step seen so far)
    loss_metric(loss)

    # previews are only refreshed every 100 steps
    if step % 100 != 0:
      continue

    display.clear_output(wait=False)
    generate_and_save_images(decoder, epoch, step, random_vector_for_generation)
    print('Epoch: %s step: %s mean loss = %s' % (epoch, step, loss_metric.result().numpy()))