In [59]:
import os
import git
import sys
import glob
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from pathlib import Path
from matplotlib import rcParams
from keras.models import Model
from keras.layers import Conv2D, BatchNormalization, Conv2DTranspose
from keras.layers import Input, Layer, Lambda, Flatten, Reshape
from keras.layers import Multiply, Add, Input, Dense
from keras.losses import binary_crossentropy, mse
from keras.optimizers import Adam
from keras.utils import Sequence
from keras.preprocessing.image import load_img, img_to_array
from keras import backend as K
from keras.metrics import logcosh

# Find project root directory and file path constants
repo = git.Repo('.', search_parent_directories=True)
PROJECT_DIR = os.path.dirname(repo.git_dir)
sys.path.append(PROJECT_DIR)
from config import DATA_DIR, CELEB_A_DIR

sns.set(context='notebook', style='whitegrid')
rcParams['font.family'] = 'serif'
rcParams['font.serif'] = 'times new roman'

%config InlineBackend.figure_format = 'retina'
%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Define architecture

In [84]:
class Variational(Layer):
    
    def __init__(self, *args, **kwargs):
        super().__init__(**kwargs)
        
    def build(self, input_shape):
        super().build(input_shape)
    
    def call(self, x):
        assert(isinstance(x, list))
        z_mu, z_log_sigma = x
        eps = K.random_normal(K.shape(z_log_sigma))
        z = Add()([z_mu, Multiply()([K.exp(z_log_sigma), eps])])
        return z
    
    def compute_output_shape(self, input_shape):
        assert(isinstance(input_shape, list))
        z_mu_shape, z_log_sigma_shape = input_shape
        assert(z_mu_shape == z_log_sigma_shape)
        return z_mu_shape

In [132]:
def make_latent_tensors(x):
    a = Conv2D(32, 4, strides=2, padding='same', activation='relu')(x)
    a = BatchNormalization(axis=-1, momentum=0.1, epsilon=1e-5)(a)
    a = Conv2D(64, 4, strides=2, padding='same', activation='relu')(a)
    a = BatchNormalization(axis=-1, momentum=0.1, epsilon=1e-5)(a)
    a = Conv2D(128, 4, strides=2, padding='same', activation='relu')(a)
    a = BatchNormalization(axis=-1, momentum=0.1, epsilon=1e-5)(a)
    a = Conv2D(128, 4, strides=2, padding='same', activation='relu')(a)
    a = BatchNormalization(axis=-1, momentum=0.1, epsilon=1e-5)(a)
    a = Conv2D(256, 4, strides=2, padding='same', activation='relu')(a)
    a = BatchNormalization(axis=-1, momentum=0.1, epsilon=1e-5)(a)
    a = Conv2D(512, 4, activation='relu')(a)
    a = BatchNormalization(axis=-1, momentum=0.1, epsilon=1e-5)(a)
    
    z_mu = Conv2D(32, 1)(a)
    z_mu = Reshape((32,))(z_mu)
    z_log_sigma = Conv2D(32, 1)(a)
    z_log_sigma = Reshape((32,))(z_log_sigma)
    z = Variational()([z_mu, z_log_sigma])
    return z, z_mu, z_log_sigma

In [133]:
def make_reconstruction_tensors(z):
    a = Reshape((1, 1, -1))(z)
    a = Conv2DTranspose(
            512, 1, strides=1, padding='valid', activation='relu')(a)
    a = BatchNormalization(axis=-1, momentum=0.1, epsilon=1e-5)(a)
    a = Conv2DTranspose(
            256, 4, strides=1, padding='valid', activation='relu')(a)
    a = BatchNormalization(axis=-1, momentum=0.1, epsilon=1e-5)(a)
    a = Conv2DTranspose(
            128, 4, strides=2, padding='same', activation='relu')(a)
    a = BatchNormalization(axis=-1, momentum=0.1, epsilon=1e-5)(a)
    a = Conv2DTranspose(
            128, 4, strides=2, padding='same', activation='relu')(a)
    a = BatchNormalization(axis=-1, momentum=0.1, epsilon=1e-5)(a)
    a = Conv2DTranspose(
            64, 4, strides=2, padding='same', activation='relu')(a)
    a = BatchNormalization(axis=-1, momentum=0.1, epsilon=1e-5)(a)
    a = Conv2DTranspose(
            32, 4, strides=2, padding='same', activation='relu')(a)
    a = BatchNormalization(axis=-1, momentum=0.1, epsilon=1e-5)(a)
    y = Conv2DTranspose(
            3, 4, strides=2, padding='same', activation='sigmoid')(a)
    return y

In [137]:
int(z.shape[-1])

32

In [162]:
encoder.inputs[0]

<tf.Tensor 'x:0' shape=(?, 128, 128, 3) dtype=float32>

In [160]:
decoder.inputs[0]

<tf.Tensor 'input_1:0' shape=(?, 1, 1, 32) dtype=float32>

In [166]:
encoder(x)[0]

<tf.Tensor 'encoder/variational_1/add_2/add:0' shape=(?, 32) dtype=float32>

In [156]:
BETA = 0.1

K.clear_session()

x = Input(shape=(128, 128, 3), name='x')
z, z_mu, z_log_sigma = make_latent_tensors(x)
encoder = Model(inputs=x, outputs=[z, z_mu, z_log_sigma], name='encoder')

z_ = Input(shape=(1, 1, 32))
y_ = make_reconstruction_tensors(z_)
decoder = Model(inputs=z_, outputs=y_, name='decoder')

y = decoder(encoder(x)[0])
vae = Model(inputs=x, outputs=y, name='vae')

reconstruction_loss = mse(x, y)
reconstruction_loss *= 3
reconstruction_loss = K.sum(reconstruction_loss, axis=-1)
reconstruction_loss = K.sum(reconstruction_loss, axis=-1)
reconstruction_loss = K.mean(reconstruction_loss)
latent_loss = 1 + 2*z_log_sigma - K.square(z_mu) - K.exp(2*z_log_sigma)
latent_loss = -0.5 * K.sum(latent_loss, axis=-1)
latent_loss = K.mean(latent_loss)
loss = reconstruction_loss + BETA * latent_loss

In [197]:
def make_loss_fn(f1, f2):
    def loss_fn(y_true, y_pred):
        return f1(y_true, y_pred) + f2(y_true, y_pred) 
    return loss_fn

loss_fn = make_loss_fn(latent_loss, mse)

In [201]:
def make_latent_loss(z_mu, z_log_sigma):
    latent_loss = 1 + 2*z_log_sigma - K.square(z_mu) - K.exp(2*z_log_sigma)
    latent_loss = -0.5 * K.sum(latent_loss, axis=-1)
    latent_loss = K.mean(latent_loss)
    def latent_loss_fn(x, y):
        return latent_loss
    return latent_loss_fn

latent_loss = make_latent_loss(z_mu, z_log_sigma)

In [220]:
type(vae)

keras.engine.training.Model

In [217]:
decoder.inputs[0]

<tf.Tensor 'input_1:0' shape=(?, 1, 1, 32) dtype=float32>

In [202]:
# vae.add_loss(5e-4 * latent_loss)
vae.compile(optimizer=Adam(lr=1e-5), metrics=[mse, latent_loss], loss=[loss_fn])
vae.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
x (InputLayer)               (None, 128, 128, 3)       0         
_________________________________________________________________
encoder (Model)              [(None, 32), (None, 32),  3087392   
_________________________________________________________________
decoder (Model)              (None, 128, 128, 3)       3070947   
Total params: 6,158,339
Trainable params: 6,153,859
Non-trainable params: 4,480
_________________________________________________________________


# Load data

In [203]:
def read_img(file):
    img = load_img(file)
    img = img_to_array(img)
    img /= 255.
    return img

def crop_square(img, side_length=128):
    height, width, num_channels = img.shape
    
    # Crop image to square
    extra_padding = (max(height, width) - min(height, width)) // 2
    if height > width:
        img = img[extra_padding:-extra_padding]
    elif height < width:
        img = img[:, extra_padding:-extra_padding]
        
    # Zoom
    extra_padding = (min(height, width) - side_length) // 2
    assert(extra_padding >= 0)
    img = img[extra_padding:-extra_padding, extra_padding:-extra_padding]
    return img

class ImageDataGenerator(Sequence):
    
    def __init__(
            self, data_dir, batch_size=32, shuffle=True,
            filetype='jpg', square_crop_length=128):
        if isinstance(data_dir, str):
            data_dir = Path(data_dir)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.files = list(data_dir.glob('*.{}'.format(filetype)))
        self.num_samples = len(self.files)
        self.square_crop_length = square_crop_length
        if self.shuffle:
            self.files = np.random.permutation(self.files).tolist()
    
    def __getitem__(self, index):
        idx_start = index * self.batch_size
        idx_end = (index + 1) * self.batch_size
        batch_files = self.files[idx_start:idx_end]
        imgs = [read_img(file) for file in batch_files]
        if self.square_crop_length:
            imgs = [
                crop_square(img, side_length=self.square_crop_length)
                for img in imgs]
        imgs = np.array(imgs)
        imgs = (imgs, imgs) # Return in form (x, y)
        return imgs
    
    def __len__(self):
        return self.num_samples // self.batch_size
    
    def on_epoch_end(self):
        if self.shuffle:
            self.files = np.random.permutation(self.files).tolist()

In [204]:
datagen = ImageDataGenerator(CELEB_A_DIR, batch_size=32, shuffle=True)

# Train model

In [205]:
history = vae.fit_generator(datagen, steps_per_epoch=5)

Epoch 1/1


In [210]:
encoder.inputs[0].name

'x:0'

In [None]:
(25, 20, 128, 128, 3) --> (20, 128 * 5, 128 * 5, 3)

# Original
(25, 20, 128, 128, 3)

# Transpose
(20, 25, 128, 128, 3)

# Reshape
(20, 5, 5, 128, 128, 3)

# Stack
(20, 5, 5 * 128, 128, 3)

# Transpose
(20, 5 * 128, 5, 128, 3)

# Stack
(20, 5 * 128, 5 * 128, 3)

In [252]:
NUM_IMAGES = 24
NUM_TRAVERSAL_POINTS = 20
NUM_ROWS = 4
NUM_COLS = 6
IMG_HEIGHT = 5
IMG_WIDTH = 10

test = np.zeros([NUM_IMAGES, NUM_TRAVERSAL_POINTS, IMG_HEIGHT, IMG_WIDTH])
for i in range(NUM_IMAGES):
    test[i] += i
for i in range(NUM_TRAVERSAL_POINTS):
    test[:, i] += 0.01 * i
    
print(test.shape)

test = test.transpose(1, 0, 2, 3)
print(test.shape)

test = test.reshape(NUM_TRAVERSAL_POINTS, NUM_ROWS, NUM_COLS, IMG_HEIGHT, IMG_WIDTH)
print(test.shape)

test = test.transpose(1, 3, 2, 4, 0)
print(test.shape)

test = np.vstack(test)
print(test.shape)

test = test.transpose(1, 2, 0, 3)
print(test.shape)

test = np.vstack(test)
print(test.shape)

test = test.transpose(2, 1, 0)
print(test.shape)

(24, 20, 5, 10)
(20, 24, 5, 10)
(20, 4, 6, 5, 10)
(4, 5, 6, 10, 20)
(20, 6, 10, 20)
(6, 10, 20, 20)
(60, 20, 20)
(20, 20, 60)


In [258]:
NUM_IMAGES = 24
NUM_TRAVERSAL_POINTS = 20
NUM_ROWS = 4
NUM_COLS = 6
IMG_HEIGHT = 5
IMG_WIDTH = 10

test = np.zeros([NUM_IMAGES, NUM_TRAVERSAL_POINTS, IMG_HEIGHT, IMG_WIDTH])
for i in range(NUM_IMAGES):
    test[i] += i
for i in range(NUM_TRAVERSAL_POINTS):
    test[:, i] += 0.01 * i
print(test.shape)

test = test.transpose(1, 0, 2, 3)
print(test.shape)

test = test.reshape(NUM_TRAVERSAL_POINTS, NUM_ROWS, NUM_COLS, IMG_HEIGHT, IMG_WIDTH)
print(test.shape)

test = test.transpose(0, 1, 3, 2, 4)
print(test.shape)

test = test.reshape(NUM_TRAVERSAL_POINTS, NUM_ROWS * IMG_HEIGHT, NUM_COLS * IMG_WIDTH)
print(test.shape)

(24, 20, 5, 10)
(20, 24, 5, 10)
(20, 4, 6, 5, 10)
(20, 4, 5, 6, 10)
(20, 20, 60)


In [259]:
test[0]

array([[ 0.,  0.,  0., ...,  5.,  5.,  5.],
       [ 0.,  0.,  0., ...,  5.,  5.,  5.],
       [ 0.,  0.,  0., ...,  5.,  5.,  5.],
       ...,
       [18., 18., 18., ..., 23., 23., 23.],
       [18., 18., 18., ..., 23., 23., 23.],
       [18., 18., 18., ..., 23., 23., 23.]])

In [261]:
np.linspace(-4, 4, 25)

array([-4.        , -3.66666667, -3.33333333, -3.        , -2.66666667,
       -2.33333333, -2.        , -1.66666667, -1.33333333, -1.        ,
       -0.66666667, -0.33333333,  0.        ,  0.33333333,  0.66666667,
        1.        ,  1.33333333,  1.66666667,  2.        ,  2.33333333,
        2.66666667,  3.        ,  3.33333333,  3.66666667,  4.        ])

In [275]:
test.shape

(20, 20, 60)

In [281]:
int(z_mu.shape[-1])

32