In [None]:
import os
import numpy as np
import cv2
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
from keras.preprocessing.image import ImageDataGenerator
from keras.layers import Lambda,Input,Dense,Flatten,Conv2D,Conv2DTranspose
from keras.layers import Activation,BatchNormalization,Reshape
from keras.models import Model
from keras.losses import mse, binary_crossentropy
from keras.callbacks import Callback
from keras import backend as K

# Absolute path of the working directory at notebook start-up; the other cells
# treat it as the project root that contains the `data/` folder.
original_path = os.getcwd()

## Checking dataset
[***Source: Stanford Dogs dataset***](http://vision.stanford.edu/aditya86/ImageNetDogs/)

In [None]:
# Summarise the dataset: one sub-folder per breed, one file per image.
# Paths are built explicitly instead of os.chdir()-ing through the tree, so an
# error part-way through cannot leave the kernel stranded in a sub-directory
# (which would make every later relative `data` lookup fail on re-run).
data_dir = os.path.join(original_path, 'data')

# Only directories count as breeds; stray files (e.g. .DS_Store) are ignored.
breed_dirs = [name for name in os.listdir(data_dir)
              if os.path.isdir(os.path.join(data_dir, name))]
print('number of breeds: {}'.format(len(breed_dirs)))

total_num = sum(len(os.listdir(os.path.join(data_dir, name))) for name in breed_dirs)
print('total number of img: {}'.format(total_num))

## Folder/file name cleaning

In [None]:
def name_cleaning(flag=False):
    """Normalise breed-folder and image-file names under ``<original_path>/data``.

    Parameters
    ----------
    flag : bool, default False
        When False nothing on disk is touched and a notice is printed.
        When True the two renaming passes described below are executed.

    Notes
    -----
    Pass 1 strips the leading ``n<digits>-`` identifier from each breed folder
    (assumes the raw folders look like ``n02099267-flat-coated_retriever`` --
    TODO confirm against the downloaded archive). Only the first hyphen is used
    as the separator, so hyphenated breed names are kept intact, and folders
    without that prefix are left alone, which makes a second run harmless.

    Pass 2 renames the files of every breed folder to ``<breed>_<NNNN><ext>``
    in sorted order. A file is skipped when its target name already exists.

    The working directory is never changed, so a failure cannot leave the
    kernel inside a sub-folder.
    """
    if not flag:
        print('Folder and file names are already processed')
        return

    data_dir = os.path.join(original_path, 'data')  # root data folder

    # Pass 1: clean up folder names.
    for folder in sorted(os.listdir(data_dir)):
        folder_path = os.path.join(data_dir, folder)
        if not os.path.isdir(folder_path):
            continue
        prefix, sep, breed_name = folder.partition('-')
        has_id_prefix = bool(sep) and prefix[:1] == 'n' and prefix[1:].isdigit()
        if not (has_id_prefix and breed_name):
            continue  # already cleaned, or not a dataset folder
        target_path = os.path.join(data_dir, breed_name)
        if not os.path.exists(target_path):
            os.rename(folder_path, target_path)

    # Pass 2: clean up file names inside every breed folder.
    for breed in sorted(os.listdir(data_dir)):
        breed_path = os.path.join(data_dir, breed)
        if not os.path.isdir(breed_path):
            continue
        for position, original_name in enumerate(sorted(os.listdir(breed_path)), start=1):
            extension = os.path.splitext(original_name)[-1]
            new_name = breed + '_{:04d}'.format(position) + extension
            new_path = os.path.join(breed_path, new_name)
            if not os.path.exists(new_path):
                os.rename(os.path.join(breed_path, original_name), new_path)

In [None]:
# `flag` defaults to False, so this call only prints a notice; pass flag=True
# to actually rename folders and files on disk.
name_cleaning()

## Display sample images

In [None]:
# Show an n_breed x n_samples grid: one random breed per row, random images of
# that breed across the row.
np.random.seed(seed=0)
n_samples = 5
n_breed = 5
# squeeze=False keeps `rows` two-dimensional even when either count is 1.
fig, rows = plt.subplots(n_breed, n_samples, figsize = (4*n_samples, 3*n_breed), squeeze=False)

data_dir = os.path.join(original_path, 'data')
# os.listdir() order is arbitrary; sorting makes the seeded draw reproducible
# across machines. Non-directories are not breeds and are skipped.
breed_names = sorted(name for name in os.listdir(data_dir)
                     if os.path.isdir(os.path.join(data_dir, name)))

for row, breed in zip(rows, np.random.choice(breed_names, n_breed, replace=False)):
    row[n_samples // 2].set_title(breed, fontsize=25)  # title on the middle column
    breed_dir = os.path.join(data_dir, breed)
    picks = np.random.choice(sorted(os.listdir(breed_dir)), n_samples, replace=False)
    for col_ax, img in zip(row, picks):
        col_ax.axis('off')
        rand_img = cv2.imread(os.path.join(breed_dir, img))
        if rand_img is None:
            # cv2.imread signals an unreadable file by returning None.
            continue
        col_ax.imshow(cv2.cvtColor(rand_img, cv2.COLOR_BGR2RGB))
plt.subplots_adjust(left=0.2, wspace=0.02)

## Load/pre-processing images

In [None]:
# Map every breed folder name to an integer class index.
# The listing is sorted because os.listdir() returns entries in arbitrary
# order; without it the same breed could get a different index on another
# machine or run. Plain files inside `data/` are not breeds and are skipped.
data_dir = os.path.join(original_path, 'data')
breed_list = sorted(name for name in os.listdir(data_dir)
                    if os.path.isdir(os.path.join(data_dir, name)))
idx_list = list(range(len(breed_list)))
label_dict = {breed_name: idx for idx, breed_name in enumerate(breed_list)}

In [None]:
# Read every image, convert BGR -> RGB, resize to img_rows x img_cols and
# collect pixels plus the integer breed label.
img_rows = 224
img_cols = 224
img_list = []
label_list = []

data_dir = os.path.join(original_path, 'data')
for breed in sorted(os.listdir(data_dir)):
    breed_dir = os.path.join(data_dir, breed)
    if not os.path.isdir(breed_dir):
        continue
    for img in sorted(os.listdir(breed_dir)):
        img_in = cv2.imread(os.path.join(breed_dir, img))
        if img_in is None:
            # cv2.imread returns None instead of raising on unreadable files.
            print('skipping unreadable file: {}'.format(os.path.join(breed, img)))
            continue
        img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2RGB)
        # dsize is (width, height). The interpolation flag must be passed by
        # keyword: the third positional parameter of cv2.resize is `dst`, so a
        # positional cv2.INTER_AREA is not used as the interpolation mode.
        img_in = cv2.resize(img_in, (img_cols, img_rows), interpolation=cv2.INTER_AREA)
        img_list.append(img_in)
        label_list.append(label_dict[breed])

# Build the float32 array directly (no intermediate uint8 copy), then scale
# pixel values to [0, 1].
img_data = np.array(img_list, dtype=np.float32)
img_label = np.array(label_list)
img_data /= 255.

In [None]:
# Sanity check: report the shapes of the pixel tensor and the label vector.
for template, array in (('training data shape: {}', img_data),
                        ('label shape: {}', img_label)):
    print(template.format(array.shape))

In [None]:
from sklearn.utils import shuffle
# Permute images and labels together; random_state=0 fixes the permutation.
X_shuffled, y_shuffled = shuffle(img_data,img_label,random_state=0)

In [None]:
from sklearn.model_selection import train_test_split

# Stratified 80/20 split on the breed label, seeded for repeatability.
X_train, X_valid, y_train, y_valid = train_test_split(
    X_shuffled,
    y_shuffled,
    test_size=0.2,
    random_state=0,
    stratify=y_shuffled,
)

# Keep only the first 10000 training and 1000 validation samples.
X_train, y_train = X_train[:10000], y_train[:10000]
X_valid, y_valid = X_valid[:1000], y_valid[:1000]

print('X_train shape: {}'.format(X_train.shape))
print('y_train shape: {}'.format(y_train.shape))
print('X_valid shape: {}'.format(X_valid.shape))
print('y_valid shape: {}'.format(y_valid.shape))

In [None]:
def sampling(arg):
    """Draw a latent sample with the reparameterization trick.

    Parameters
    ----------
    arg : sequence of two tensors
        ``[z_mean, z_log_var]``. Both are indexed as rank-2 tensors here
        (axis 0 = batch, axis 1 = latent dimension).

    Returns
    -------
    tensor
        ``z_mean + exp(0.5 * z_log_var) * epsilon`` where ``epsilon`` is drawn
        from a standard normal of the same (batch, dim) shape.
    """
    # Unpack the argument rather than reading notebook-level globals, so the
    # function works for whatever tensors the Lambda layer passes in.
    z_mean, z_log_var = arg
    batch = K.shape(z_mean)[0]      # dynamic batch size
    dim = K.int_shape(z_mean)[1]    # static latent dimension
    epsilon = K.random_normal(shape=(batch, dim), mean=0.0, stddev=1.0)
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

In [None]:
# Model geometry is read off the training tensor (axes 1..3 of X_train).
img_rows, img_cols, channel = X_train.shape[1:]
input_shape = (img_rows, img_cols, channel)

# Training hyper-parameters.
batch_size, latent_dim, epochs = 128, 2, 5

In [None]:
# Encoder: three Conv -> BatchNorm -> ReLU stages, two dense layers, then the
# two latent heads (mean and log-variance) and the sampled latent vector.
inputs = Input(shape=input_shape)

# (filters, stride, activation applied directly on the conv layer)
# NOTE(review): the third stage applies ReLU on the conv output *and* again
# after batch-norm; kept as-is so the network is unchanged.
conv_stages = ((16, 2, None), (32, 1, None), (64, 2, 'relu'))

x = inputs
for n_filters, stride, conv_activation in conv_stages:
    x = Conv2D(n_filters, 3, strides=stride, padding='same', activation=conv_activation)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

# Remember the feature-map shape so the decoder can mirror it.
before_flatten_shape = K.int_shape(x)

x = Flatten()(x)
for n_units in (64, 32):
    x = Dense(n_units, activation='relu')(x)

z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)

z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])

encoder = Model(inputs, [z_mean, z_log_var, z])
encoder.summary()

In [None]:
# Decoder: dense layers expand the latent vector back to the encoder's last
# feature-map size, then transposed convolutions upsample to image size.
latent_inputs = Input(shape=(latent_dim,))
x = Dense(32, activation='relu')(latent_inputs)
x = Dense(64, activation='relu')(x)
x = Dense(before_flatten_shape[1]*before_flatten_shape[2]*before_flatten_shape[3], activation='relu')(x)
x = Reshape((before_flatten_shape[1],before_flatten_shape[2],before_flatten_shape[3]))(x)
x = Conv2DTranspose(64,3,strides=2,padding='same')(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)

# NOTE(review): the next two stages apply ReLU on the conv output and again
# after batch-norm; left unchanged here.
x = Conv2DTranspose(32,3,activation='relu',strides=1,padding='same')(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)

x = Conv2DTranspose(16,3,activation='relu',strides=2,padding='same')(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)

# The reconstruction must have as many channels as the input images
# (`channel`): the loss compares the flattened input with the flattened
# output element by element, so a single-channel output cannot match it.
outputs = Conv2DTranspose(channel,3,activation='sigmoid',padding='same')(x)

# instantiate decoder model
decoder = Model(latent_inputs, outputs)
decoder.summary()

In [None]:
# End-to-end VAE: feed the encoder's sampled latent vector (third encoder
# output) through the decoder.
latent_sample = encoder(inputs)[2]
outputs = decoder(latent_sample)
vae = Model(inputs, outputs)

In [None]:
# VAE objective = reconstruction term + beta * KL divergence to N(0, I).
beta = 1 # 1 --> regular VAE

# Alternative reconstruction term:
# reconstruction_loss = mse(K.flatten(inputs), K.flatten(outputs))
reconstruction_loss = binary_crossentropy(K.flatten(inputs),K.flatten(outputs))
# binary_crossentropy averages over every flattened element, so multiply by
# the number of elements per image to get a per-image sum. The colour channels
# are part of that count; leaving them out under-weights reconstruction
# against the KL term.
reconstruction_loss *= img_rows * img_cols * channel

# Closed-form KL divergence, summed over the latent axis.
kl_loss = -0.5 * K.sum(beta*(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)), axis=-1)
vae_loss = K.mean(reconstruction_loss + kl_loss)

In [None]:
class NEpochPrint(Callback):
    """Print a one-line metric summary on epoch 1 and every ``display_step``-th epoch.

    Parameters
    ----------
    display_step : int
        Interval, in epochs, between printed summaries.
    """

    def __init__(self, display_step):
        # Let the base class initialise its own state first.
        super().__init__()
        self.epoch = 0
        self.display_step = display_step

    def on_epoch_end(self, epoch, logs=None):
        """Report the metrics in ``logs`` for the epoch that just finished.

        ``epoch`` is the zero-based index passed in by Keras. Using it instead
        of a private counter keeps the numbering right when the same callback
        instance is reused for a second fit.
        """
        logs = logs or {}
        self.epoch = epoch + 1
        if not (self.epoch == 1 or self.epoch % self.display_step == 0):
            return

        params = getattr(self, 'params', None) or {}
        # Prefer the metric order given by Keras; fall back to whatever keys
        # are present in `logs` when `params` carries no 'metrics' entry.
        metric_names = params.get('metrics') or sorted(logs)
        shown = ['{}: {:.4f}'.format(name, logs[name])
                 for name in metric_names
                 if logs.get(name) is not None]
        print('Epoch: {}/{} ..... {}'.format(self.epoch,
                                             params.get('epochs', '?'),
                                             ' - '.join(shown)))
NEpochPrinter = NEpochPrint(display_step=2)

In [None]:
# Generator built with default arguments; its flow() is what feeds X_train
# to training in batches.
datagen = ImageDataGenerator()

In [None]:
def train_or_load_weights(flag, weights_path='dcp_v1.h5'):
    """Train the VAE and save its weights, or restore previously saved weights.

    Parameters
    ----------
    flag : {'train', 'load'}
        'train' fits ``vae`` on ``X_train`` and writes the weights to
        ``weights_path``; 'load' reads the weights from ``weights_path``.
    weights_path : str, default 'dcp_v1.h5'
        File the weights are written to / read from.

    Returns
    -------
    The object returned by ``vae.fit_generator`` when ``flag == 'train'``,
    otherwise ``None``.

    Raises
    ------
    ValueError
        If ``flag`` is neither 'train' nor 'load' (previously such a call
        silently did nothing).
    """
    if flag == 'train':
        hist = vae.fit_generator(datagen.flow(X_train,batch_size=batch_size),
                                 steps_per_epoch = len(X_train)//batch_size,
                                 epochs=epochs,
                                 validation_data=(X_valid, None),
                                 verbose=0,
                                 callbacks=[NEpochPrinter])
        vae.save_weights(weights_path)
        return hist
    if flag == 'load':
        # Loading into `vae` restores the encoder and decoder weights together.
        vae.load_weights(weights_path)
        return None
    raise ValueError("flag must be 'train' or 'load', got {!r}".format(flag))

In [None]:
# The objective is attached with add_loss(), which is why compile() is called
# without a `loss` argument.
vae.add_loss(vae_loss)
vae.compile(optimizer='rmsprop')
vae.summary()

In [None]:
# Train from scratch and keep the object returned by training; call with
# flag='load' instead to restore previously saved weights.
hist = train_or_load_weights(flag='train')