In [282]:
import os
import git
import sys
import glob
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from pathlib import Path
from matplotlib import rcParams
from keras.models import Model
from keras.layers import Conv2D, BatchNormalization, Conv2DTranspose
from keras.layers import Input, Layer, Lambda, Flatten, Reshape
from keras.layers import Multiply, Add
from keras.losses import binary_crossentropy, mse
from keras.optimizers import Adam
from keras.utils import Sequence
from keras.preprocessing.image import load_img, img_to_array
from keras import backend as K

# Find project root directory and file path constants.
# The enclosing git repository is located by searching upward from the current
# working directory; PROJECT_DIR is the directory that contains `repo.git_dir`.
repo = git.Repo('.', search_parent_directories=True)
PROJECT_DIR = os.path.dirname(repo.git_dir)
# Put the project root on sys.path so the `config` import below can resolve
# (presumably config.py lives at the project root -- confirm).
sys.path.append(PROJECT_DIR)
from config import DATA_DIR, CELEB_A_DIR

# Plot styling: seaborn "whitegrid" theme in notebook context, serif font.
sns.set(context='notebook', style='whitegrid')
rcParams['font.family'] = 'serif'
rcParams['font.serif'] = 'times new roman'

%config InlineBackend.figure_format = 'retina'
%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Define architecture

In [283]:
class Variational(Layer):
    """Sampling layer implementing the VAE reparameterization trick.

    The layer takes a list ``[z_mu, z_log_sigma]`` of two tensors and returns
    ``z = z_mu + exp(z_log_sigma) * eps``, where ``eps`` is drawn with
    ``K.random_normal`` using the dynamic shape of ``z_log_sigma``.
    The layer defines no weights of its own.
    """

    def __init__(self, *args, **kwargs):
        # Positional arguments are accepted but ignored; only keyword
        # arguments are forwarded to the base Layer constructor.
        super().__init__(**kwargs)

    def build(self, input_shape):
        # No weights to create; defer to the base class bookkeeping.
        super().build(input_shape)

    def call(self, x):
        """Sample ``z`` from the distribution described by ``x``.

        Args:
            x: List ``[z_mu, z_log_sigma]``.

        Returns:
            Tensor ``z_mu + exp(z_log_sigma) * eps``.
        """
        assert isinstance(x, list)
        z_mu, z_log_sigma = x
        eps = K.random_normal(K.shape(z_log_sigma))
        # Plain tensor arithmetic instead of instantiating fresh Add() and
        # Multiply() layer objects inside another layer's call(); the
        # computed value is the same.
        return z_mu + K.exp(z_log_sigma) * eps

    def compute_output_shape(self, input_shape):
        """Return the shape of ``z``, which equals the shape of ``z_mu``."""
        assert isinstance(input_shape, list)
        z_mu_shape, z_log_sigma_shape = input_shape
        assert z_mu_shape == z_log_sigma_shape
        return z_mu_shape

In [284]:
def make_latent_tensors(x):
    """Build the convolutional encoder graph on top of the tensor ``x``.

    Five stride-2 4x4 convolutions (32, 64, 128, 128, 256 filters) are each
    followed by batch normalization, then a 4x4 convolution with 512 filters
    and default stride/padding, again followed by batch normalization. Two
    separate 1x1 convolutions with 32 filters produce the mean and log-sigma
    maps, which are flattened and passed to a ``Variational`` sampling layer.

    Args:
        x: Input image tensor.

    Returns:
        Tuple ``(z, z_mu, z_log_sigma)``: the sampled latent tensor and the
        two flattened parameter tensors it was sampled from.
    """
    # NOTE(review): momentum=0.1 / epsilon=1e-5 equal PyTorch's BatchNorm
    # defaults, but Keras weights the *old* moving average by `momentum`
    # (the PyTorch-equivalent value would be 0.9) -- confirm this is intended.
    hidden = x
    for num_filters in (32, 64, 128, 128, 256):
        hidden = Conv2D(
            num_filters, 4, strides=2, padding='same',
            activation='relu')(hidden)
        hidden = BatchNormalization(
            axis=-1, momentum=0.1, epsilon=1e-5)(hidden)

    # Final feature stage: default stride and padding.
    hidden = Conv2D(512, 4, activation='relu')(hidden)
    hidden = BatchNormalization(axis=-1, momentum=0.1, epsilon=1e-5)(hidden)

    # Two independent 1x1 heads parameterize the latent distribution.
    mu_map = Conv2D(32, 1)(hidden)
    z_mu = Flatten()(mu_map)
    log_sigma_map = Conv2D(32, 1)(hidden)
    z_log_sigma = Flatten()(log_sigma_map)

    z = Variational()([z_mu, z_log_sigma])
    return z, z_mu, z_log_sigma

In [285]:
def make_reconstruction_tensors(z):
    """Build the transposed-convolution decoder graph on top of ``z``.

    ``z`` is reshaped to ``(1, 1, 32)`` and passed through five
    Conv2DTranspose + BatchNormalization stages, one further ReLU
    Conv2DTranspose stage with 32 filters (no batch normalization after it),
    and a final 3-filter Conv2DTranspose with a sigmoid activation.

    Args:
        z: Latent tensor that can be reshaped to ``(1, 1, 32)`` per sample.

    Returns:
        The sigmoid-activated output tensor of the last layer.
    """
    # (filters, kernel_size, strides, padding) for each batch-normalized stage.
    normalized_stages = (
        (512, 1, 1, 'valid'),
        (256, 4, 1, 'valid'),
        (128, 4, 2, 'same'),
        (128, 4, 2, 'same'),
        (64, 4, 2, 'same'),
    )

    hidden = Reshape((1, 1, 32))(z)
    for num_filters, kernel, stride, pad in normalized_stages:
        hidden = Conv2DTranspose(
            num_filters, kernel, strides=stride, padding=pad,
            activation='relu')(hidden)
        hidden = BatchNormalization(
            axis=-1, momentum=0.1, epsilon=1e-5)(hidden)

    # Last two upsampling stages are not followed by batch normalization.
    hidden = Conv2DTranspose(
        32, 4, strides=2, padding='same', activation='relu')(hidden)
    return Conv2DTranspose(
        3, 4, strides=2, padding='same', activation='sigmoid')(hidden)

In [301]:
# Weight applied to the latent (KL) term of the total loss.
BETA = 0.1

K.clear_session()

# Encoder: image -> (sampled z, z_mu, z_log_sigma).
x = Input(shape=(128, 128, 3), name='x')
z, z_mu, z_log_sigma = make_latent_tensors(x)
encoder = Model(inputs=x, outputs=[z, z_mu, z_log_sigma], name='encoder')

# Decoder: latent vector -> reconstructed image.
# The declared input shape is the flat 32-vector that the encoder emits
# (its outputs are flattened), so the standalone decoder accepts encoder
# output directly; make_reconstruction_tensors reshapes it to (1, 1, 32).
z_ = Input(shape=(32,))
y_ = make_reconstruction_tensors(z_)
decoder = Model(inputs=z_, outputs=y_, name='decoder')

# End-to-end model: decode the sampled z (first encoder output).
y = decoder(encoder(x)[0])
vae = Model(inputs=x, outputs=y, name='vae')

# Reconstruction term: mse() averages the squared error over the last
# (channel) axis, so multiplying by 3 turns that mean into a sum over the
# three channels. The two sums then reduce the remaining spatial axes and
# the mean averages over the batch.
reconstruction_loss = mse(x, y)
reconstruction_loss *= 3
reconstruction_loss = K.sum(reconstruction_loss, axis=-1)
reconstruction_loss = K.sum(reconstruction_loss, axis=-1)
reconstruction_loss = K.mean(reconstruction_loss)

# Latent term: -0.5 * sum(1 + 2*log_sigma - mu^2 - exp(2*log_sigma)) per
# sample, averaged over the batch.
latent_loss = 1 + 2*z_log_sigma - K.square(z_mu) - K.exp(2*z_log_sigma)
latent_loss = -0.5 * K.sum(latent_loss, axis=-1)
latent_loss = K.mean(latent_loss)
loss = reconstruction_loss + BETA * latent_loss

# The loss is attached to the model itself, so compile() gets no `loss=`.
vae.add_loss(loss)
# NOTE(review): the training log in this notebook reports only `loss`, so the
# tensors passed via `metrics=` do not appear to be tracked -- confirm against
# the installed Keras version.
vae.compile(optimizer=Adam(lr=1e-5), metrics=[latent_loss, reconstruction_loss])

In [297]:
# Print the layer-by-layer summary (output shapes, parameter counts) of the
# end-to-end model.
vae.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
x (InputLayer)               (None, 128, 128, 3)       0         
_________________________________________________________________
encoder (Model)              [(None, 32), (None, 32),  3087392   
_________________________________________________________________
decoder (Model)              (None, 128, 128, 3)       3070819   
Total params: 6,158,211
Trainable params: 6,153,795
Non-trainable params: 4,416
_________________________________________________________________


# Load data

In [279]:
def read_img(file):
    """Load an image file and return it as an array divided by 255.

    Args:
        file: Path of the image, as accepted by ``load_img``.

    Returns:
        The array produced by ``img_to_array``, scaled in place by 1/255.
    """
    pixels = img_to_array(load_img(file))
    pixels /= 255.
    return pixels

def crop_square(img, side_length=128):
    """Center-crop an image array to ``side_length`` x ``side_length``.

    Args:
        img: Array whose first two axes are height and width. Any trailing
            axes (e.g. channels) are left untouched.
        side_length: Edge length, in pixels, of the square output.

    Returns:
        The centered ``side_length`` x ``side_length`` window of ``img``,
        obtained by slicing the first two axes.

    Raises:
        AssertionError: If ``side_length`` is larger than the shorter side
            of the image.
    """
    height, width = img.shape[:2]
    assert side_length <= min(height, width)

    # Use explicit start/stop offsets rather than symmetric
    # `pad:-pad` slices: when pad is 0, `-0 == 0` makes that slice empty,
    # and odd size differences would leave the result one pixel too large.
    top = (height - side_length) // 2
    left = (width - side_length) // 2
    return img[top:top + side_length, left:left + side_length]

class ImageDataGenerator(Sequence):
    """Keras ``Sequence`` that yields batches of images read from a directory.

    Files matching ``*.<filetype>`` directly inside ``data_dir`` are listed
    once at construction. Each batch is returned as ``(images, None)``.
    """

    def __init__(
            self, data_dir, batch_size=32, shuffle=True,
            filetype='jpg', square_crop_length=128):
        """Collect the image files and optionally shuffle them.

        Args:
            data_dir: Directory to scan, as a ``str`` or ``Path``.
            batch_size: Number of files per batch.
            shuffle: If true, permute the file list now and after each epoch.
            filetype: File extension (without dot) used in the glob pattern.
            square_crop_length: Side length passed to ``crop_square``; a
                falsy value disables cropping.
        """
        directory = Path(data_dir) if isinstance(data_dir, str) else data_dir
        pattern = '*.{}'.format(filetype)

        self.batch_size = batch_size
        self.shuffle = shuffle
        self.files = list(directory.glob(pattern))
        self.num_samples = len(self.files)
        self.square_crop_length = square_crop_length
        if self.shuffle:
            self.files = np.random.permutation(self.files).tolist()

    def __getitem__(self, index):
        """Return batch number ``index`` as a tuple ``(x, y)`` with ``y=None``."""
        first = index * self.batch_size
        last = first + self.batch_size
        crop_length = self.square_crop_length

        batch = []
        for path in self.files[first:last]:
            img = read_img(path)
            if crop_length:
                img = crop_square(img, side_length=crop_length)
            batch.append(img)
        return np.array(batch), None

    def __len__(self):
        """Number of full batches; a trailing partial batch is not counted."""
        return self.num_samples // self.batch_size

    def on_epoch_end(self):
        """Re-permute the file list between epochs when shuffling is on."""
        if not self.shuffle:
            return
        self.files = np.random.permutation(self.files).tolist()

In [280]:
# Batch generator over the image files found in CELEB_A_DIR, reshuffled each
# epoch; the default 128-pixel square crop is applied to every image.
datagen = ImageDataGenerator(CELEB_A_DIR, batch_size=32, shuffle=True)

# Train model

In [302]:
# Train with all fit_generator defaults; the log below shows this amounts to
# a single epoch of 6331 steps over `datagen`.
vae.fit_generator(datagen)

Epoch 1/1
  32/6331 [..............................] - ETA: 2:58:28 - loss: 4402.0537

KeyboardInterrupt: 