# Imports

In [1]:
from keras.datasets import mnist
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D, Add
from keras.layers.advanced_activations import PReLU, LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.applications import VGG19
from keras.models import Sequential, Model
from keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
import sys
from helper import DataLoader
import numpy as np
import os
from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot
import pydot
from keras.callbacks import TensorBoard, ModelCheckpoint
from glob import glob
from tqdm import tqdm_notebook
import scipy
import imageio
from tensorlayer.prepro import *
import tensorlayer as tl

Using TensorFlow backend.


# Params

In [2]:
hr_shape = (224,224,3)
img_hr = Input(shape=hr_shape)
lr_shape = (56,56,3)
img_lr = Input(shape=lr_shape)

# Hyper Params

In [3]:
batchSize = 16
optimizer = Adam(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0, amsgrad=False)
epochs = 2000
ni = np.sqrt(batchSize)

Instructions for updating:
Colocations handled automatically by placer.


# Generator

In [4]:
def build_generator(img_lr,lr_shape):
    def residual_block(layer_input, filters):
        """Residual block described in paper"""
        d = Conv2D(filters, kernel_size=3, strides=1,
                   padding='same')(layer_input)
        d = Activation('relu')(d)
        d = BatchNormalization(momentum=0.8)(d)
        d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(d)
        d = BatchNormalization(momentum=0.8)(d)
        d = Add()([d, layer_input])
        return d

    def deconv2d(layer_input):
        """Layers used during upsampling"""
        u = UpSampling2D(size=2)(layer_input)
        u = Conv2D(256, kernel_size=3, strides=1, padding='same')(u)
        u = Activation('relu')(u)
        return u

    # Low resolution image input
    img_lr = Input(shape=lr_shape)

    # Pre-residual block
    c1 = Conv2D(64, kernel_size=9, strides=1, padding='same')(img_lr)
    c1 = Activation('relu')(c1)

    # Propogate through residual blocks
    r = residual_block(c1, 64)
    for _ in range(16 - 1):
        r = residual_block(r, 64)

    # Post-residual block
    c2 = Conv2D(64, kernel_size=3, strides=1, padding='same')(r)
    c2 = BatchNormalization(momentum=0.8)(c2)
    c2 = Add()([c2, c1])

    # Upsampling
    u1 = deconv2d(c2)
    u2 = deconv2d(u1)

    # Generate high resolution output
    gen_hr = Conv2D(3, kernel_size=9, strides=1,
                    padding='same', activation='tanh')(u2)

    return Model(img_lr, gen_hr)

gen = build_generator(img_lr,lr_shape)

In [None]:
SVG(model_to_dot(gen, show_layer_names=True, show_shapes=True).create(prog='dot', format='svg'))

In [None]:
gen.summary()

# Discriminator

In [5]:
def build_discriminator(hr_shape):

        def d_block(layer_input, filters, strides=1, bn=True):
            """Discriminator layer"""
            d = Conv2D(filters, kernel_size=3, strides=strides,
                       padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        # Input img
        d0 = Input(shape=hr_shape)

        d1 = d_block(d0, 64, bn=False)
        d2 = d_block(d1, 64, strides=2)
        d3 = d_block(d2, 64*2)
        d4 = d_block(d3, 64*2, strides=2)
        d5 = d_block(d4, 64*4)
        d6 = d_block(d5, 64*4, strides=2)
        d7 = d_block(d6, 64*8)
        d8 = d_block(d7, 64*8, strides=2)

        d9 = Dense(64*16)(d8)
        d10 = LeakyReLU(alpha=0.2)(d9)
        validity = Dense(1, activation='sigmoid')(d10)

        return Model(d0, validity)

disc = build_discriminator(hr_shape)
disc.compile(loss='mse',
             optimizer=optimizer,
             metrics=['accuracy'])

In [None]:
SVG(model_to_dot(gen, show_layer_names=True, show_shapes=True).create(prog='dot', format='svg'))

In [None]:
disc.summary()

# VGG

In [6]:
def build_vgg(hr_shape):
    """
    Builds a pre-trained VGG19 model that outputs image features extracted at the
    third block of the model
    """
    vgg = VGG19(weights="imagenet")
    # Set outputs to outputs of last conv. layer in block 3
    # See architecture at: https://github.com/keras-team/keras/blob/master/keras/applications/vgg19.py
    vgg.outputs = [vgg.layers[9].output]

    img = Input(shape=hr_shape)

    # Extract image features
    img_features = vgg(img)

    return Model(img, img_features)

vgg = build_vgg(hr_shape)
vgg.trainable = False
vgg.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])

# Combined

In [7]:
gen_hr = gen(img_lr)
vgg_features = vgg(gen_hr)
validity = disc(gen_hr)
combined = Model([img_lr, img_hr], [validity, vgg_features])
combined.compile(loss=['binary_crossentropy', 'mse'],
                              loss_weights=[1e-3, 1],
                              optimizer=optimizer)

In [None]:
SVG(model_to_dot(combined, show_layer_names=True, show_shapes=True).create(prog='dot', format='svg'))

# Training

In [8]:
path = "./mainDataset"
dataset_name = "train"
steps = len(glob(path+'/%s/*' % (dataset_name)))//batchSize
# for the VGG feature out true labels
patch = int(224 / 2**4)
disc_patch = (patch, patch, 1)

In [9]:
train_hr_img_list = sorted(tl.files.load_file_list(path=path+'/%s/' % (dataset_name), regx='.*.png', printable=False))

train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=path+'/%s/' % (dataset_name), n_threads=32)

[TL] read 32 from ./mainDataset/train/
[TL] read 64 from ./mainDataset/train/
[TL] read 96 from ./mainDataset/train/
[TL] read 128 from ./mainDataset/train/
[TL] read 160 from ./mainDataset/train/
[TL] read 192 from ./mainDataset/train/
[TL] read 224 from ./mainDataset/train/
[TL] read 256 from ./mainDataset/train/
[TL] read 288 from ./mainDataset/train/
[TL] read 320 from ./mainDataset/train/
[TL] read 352 from ./mainDataset/train/
[TL] read 384 from ./mainDataset/train/
[TL] read 416 from ./mainDataset/train/
[TL] read 448 from ./mainDataset/train/
[TL] read 480 from ./mainDataset/train/
[TL] read 512 from ./mainDataset/train/
[TL] read 544 from ./mainDataset/train/
[TL] read 576 from ./mainDataset/train/
[TL] read 608 from ./mainDataset/train/
[TL] read 640 from ./mainDataset/train/
[TL] read 672 from ./mainDataset/train/
[TL] read 704 from ./mainDataset/train/
[TL] read 736 from ./mainDataset/train/
[TL] read 768 from ./mainDataset/train/
[TL] read 800 from ./mainDataset/train/
[TL

In [10]:
tensorboard = TensorBoard(
  log_dir='log/pretrain_fix/run3',
  histogram_freq=0,
  batch_size=batchSize,
  write_graph=True,
  write_grads=True
)
tensorboard.set_model(combined)

In [None]:
modelSavePath = 'checkpoints/'
if not os.path.exists(modelSavePath):
    os.makedirs(modelSavePath)

modelcheckpoint = ModelCheckpoint(
    filepath = modelSavePath+"baseline.hdf5",
    monitor='g_loss',
    verbose=0,
    mode="auto",
    save_best_only=True
)
modelcheckpoint.set_model(gen)

In [None]:
#gen.load_weights("checkpoints/gen_init.h5")

In [11]:
def datagen(dev_hr_imgs,batchSize,is_testing=False):
    while(True):
        imgs_hr=[]
        imgs_lr=[]
        imgs = np.random.choice(dev_hr_imgs,batchSize)
        img_hr = tl.prepro.threading_data(imgs, fn=crop, wrg=224, hrg=224, is_random=True)
        img_lr = tl.prepro.threading_data(img_hr, fn=imresize,size=[56, 56], interp='bicubic', mode=None)
    
        imgs_hr = np.array(img_hr) / 127.5 - 1.
        imgs_lr = np.array(img_lr) / 127.5 - 1.

        yield imgs_hr, imgs_lr

In [12]:
datagenObj = datagen(train_hr_imgs,batchSize)

In [13]:
sample_hr = tl.prepro.threading_data(train_hr_imgs[0:batchSize], fn=crop, wrg=224, hrg=224, is_random=True)
sample_lr = tl.prepro.threading_data(sample_hr, fn=imresize,size=[56, 56], interp='bicubic', mode=None)

In [14]:
tl.vis.save_images(sample_hr, [int(ni), int(ni)],'images/'+dataset_name+'/sample_hr.png')
tl.vis.save_images(sample_lr, [int(ni), int(ni)],'images/'+dataset_name+'/sample_lr.png')

## Pretrain Gen

In [None]:
for epoch in range(100):
    for step in tqdm_notebook(range(0,steps)):
        start_time = datetime.datetime.now()      
        # ------------------
        #  Train Generator
        # ------------------

        # Sample images and their conditioning counterparts
        imgs_hr, imgs_lr = next(datagenObj)

        # The generators want the discriminators to label the generated images as real
        valid = np.random.uniform(low=0.6, high=1, size=((batchSize,) + disc_patch))

        # Extract ground truth image features using pre-trained VGG19 model
        image_features = vgg.predict(imgs_hr)

        # Train the generators
        g_loss = combined.train_on_batch([imgs_lr, imgs_hr], [valid, image_features])
        
        tb_step = step + int(epoch*(steps))
        tensorboard.on_epoch_end(tb_step, {"g_int_loss":g_loss[0]})
        
        
    if epoch % 10 == 0:
        out = gen.predict(sample_lr)
        tl.vis.save_images(out, [int(ni), int(ni)],'images/'+dataset_name+'/A_train.png')

    elapsed_time = datetime.datetime.now() - start_time
    gen.save_weights("./checkpoints/gen_init.h5")
    # Plot the progress
    print("Epoch %d time: %s" % (epoch, elapsed_time))

    # If at save interval => save generated image samples
tensorboard.on_train_end(None)

HBox(children=(IntProgress(value=0, max=182), HTML(value='')))

Instructions for updating:
Use tf.cast instead.

Epoch 0 time: 0:00:01.643863


HBox(children=(IntProgress(value=0, max=182), HTML(value='')))


Epoch 1 time: 0:00:00.442170


HBox(children=(IntProgress(value=0, max=182), HTML(value='')))


Epoch 2 time: 0:00:00.440416


HBox(children=(IntProgress(value=0, max=182), HTML(value='')))


Epoch 3 time: 0:00:00.440872


HBox(children=(IntProgress(value=0, max=182), HTML(value='')))


Epoch 4 time: 0:00:00.442805


HBox(children=(IntProgress(value=0, max=182), HTML(value='')))


Epoch 5 time: 0:00:00.442854


HBox(children=(IntProgress(value=0, max=182), HTML(value='')))


Epoch 6 time: 0:00:00.437036


HBox(children=(IntProgress(value=0, max=182), HTML(value='')))


Epoch 7 time: 0:00:00.435544


HBox(children=(IntProgress(value=0, max=182), HTML(value='')))


Epoch 8 time: 0:00:00.440912


HBox(children=(IntProgress(value=0, max=182), HTML(value='')))


Epoch 9 time: 0:00:00.441503


HBox(children=(IntProgress(value=0, max=182), HTML(value='')))

In [None]:
tensorboard = TensorBoard(
  log_dir='log/train_fix/run4',
  histogram_freq=0,
  batch_size=batchSize,
  write_graph=True,
  write_grads=True
)
tensorboard.set_model(combined)

In [None]:
'''
SRGAN - 2and 3

'''

In [None]:
d_idx = 0
g_idx = 0
for epoch in range(epochs):
    for step in tqdm_notebook(range(0,steps)):
        start_time = datetime.datetime.now()     
        if step % 2 == 0:
            # ----------------------
            #  Train Discriminator
            # ----------------------
            d_idx += 1
            # Sample images and their conditioning counterparts
            imgs_hr, imgs_lr = next(datagenObj)

            # From low res. image generate high res. version
            fake_hr = gen.predict(imgs_lr)

            valid = np.random.uniform(low=0.8, high=1, size=((batchSize,) + disc_patch))
            fake = np.random.uniform(low=0, high=0.2, size=((batchSize,) + disc_patch))

            # Train the discriminators (original images = real / generated = Fake)
            disc.trainable = True
            d_loss_real = disc.train_on_batch(imgs_hr, valid)
            d_loss_fake = disc.train_on_batch(fake_hr, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
            disc.trainable = False
            
            tensorboard.on_epoch_end(d_idx, {"d_loss": d_loss[0]})
            
        else:
            # ------------------
            #  Train Generator
            # ------------------
            g_idx += 1
            # Sample images and their conditioning counterparts
            imgs_hr, imgs_lr = next(datagenObj)

            # The generators want the discriminators to label the generated images as real
            valid = np.random.uniform(low=0.9, high=1, size=((batchSize,) + disc_patch))

            # Extract ground truth image features using pre-trained VGG19 model
            image_features = vgg.predict(imgs_hr)

            # Train the generators
            g_loss = combined.train_on_batch([imgs_lr, imgs_hr], [valid, image_features])
            tensorboard.on_epoch_end(g_idx, {"g_loss":g_loss[0]})
        
    out = gen.predict(sample_lr)
    tl.vis.save_images(out, [int(ni), int(ni)],'images/'+dataset_name+'/train_%d.png' % int(epoch*steps + step))
    if(epoch % 100 == 0):
        out = gen.predict(sample_lr)
        tl.vis.save_images(out, [int(ni), int(ni)],'images/'+dataset_name+'/train_%d.png' % int(epoch))
    elapsed_time = datetime.datetime.now() - start_time
    gen.save_weights("./checkpoints/gen.h5")
    disc.save_weights("./checkpoints/disc.h5")
    # Plot the progress
    print("Epoch %d time: %s" % (epoch, elapsed_time))
tensorboard.on_train_end(None)

In [None]:
 elapsed_time = datetime.datetime.now() - start_time

In [None]:
elapsed_time