In [None]:
from __future__ import print_function, division
## python libs
import os
import numpy as np
## tf-Keras libs
import tensorflow as tf
import keras.backend as K
from keras.models import Model
from keras.optimizers import Adam
from keras.layers import Input, Dropout, Concatenate
#from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.layers import BatchNormalization, Activation, MaxPooling2D
from keras.applications import vgg19

In [None]:
def VGG19_Content(dataset='imagenet'):
    """
       Build a frozen VGG19 feature extractor.

       dataset: value forwarded to VGG19 as its `weights` argument.
       Returns a Model mapping the VGG19 input to the output(s) of the
       layers listed in `wanted_layers` (currently only 'block5_conv2').
    """
    # convolutional base only (no classifier head)
    backbone = vgg19.VGG19(include_top=False, weights=dataset)
    backbone.trainable = False
    wanted_layers = ['block5_conv2']
    feature_maps = []
    for layer_name in wanted_layers:
        feature_maps.append(backbone.get_layer(layer_name).output)
    return Model(backbone.input, feature_maps)

In [None]:
class FUNIE_GAN():
    """
       Container for the FUnIE-GAN networks.

       __init__ builds and stores:
         discriminator -- patch discriminator, compiled with MSE loss
         generator     -- encoder/decoder with skip connections (FUNIE_generator2)
         combined      -- [img_A, img_B] -> [validity, fake_A]; the discriminator
                          is frozen here so only the generator is trained
         vgg_content   -- frozen VGG19 feature extractor used by total_gen_loss
         disc_patch    -- shape (rows, cols, 1) of one discriminator output map,
                          i.e. the shape the adversarial target arrays must have
    """
    def __init__(self, imrow=256, imcol=256, imchan=3, loss_meth='wgan'):
        """
           imrow, imcol, imchan: input image height, width and channel count.
           loss_meth: accepted for interface compatibility; it is not read
                      anywhere in this class (both models are compiled with
                      'mse' for the adversarial term).
        """
        ## input image shape
        self.img_rows, self.img_cols, self.channels = imrow, imcol, imchan
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        ## input images and their conditioning images
        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)
        ## conv 5_2 content from vgg19 network
        self.vgg_content = VGG19_Content()
        ## output shape of D (patchGAN); derived from the image size instead of
        ## being hard-coded to (16, 16, 1), which was only right for 256x256
        self.disc_patch = self._disc_patch_shape(self.img_rows, self.img_cols)
        ## number of filters in the first layer of G and D
        self.gf, self.df = 32, 32
        optimizer = Adam(0.0003, 0.5)
        ## Build and compile the discriminator
        self.discriminator = self.FUNIE_discriminator()
        self.discriminator.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
        ## Build the generator
        self.generator = self.FUNIE_generator2()
        ## By conditioning on B generate a fake version of A
        fake_A = self.generator(img_B)
        ## For the combined model we will only train the generator
        self.discriminator.trainable = False
        ## Discriminators determines validity of translated images / condition pairs
        valid = self.discriminator([fake_A, img_B])
        ## compute the combined loss
        self.combined = Model(inputs=[img_A, img_B], outputs=[valid, fake_A])
        self.combined.compile(loss=['mse', self.total_gen_loss], loss_weights=[0.2, 0.8], optimizer=optimizer)

    @staticmethod
    def _disc_patch_shape(img_rows, img_cols, n_downsamples=4):
        """
           Spatial shape of the discriminator output for a given input size.

           FUNIE_discriminator applies `n_downsamples` stride-2 convolutions
           with padding='same' (each maps a length n to ceil(n/2)) followed by
           a stride-1 'same' convolution with a single filter.
           Returns (rows, cols, 1).
        """
        def reduced(length):
            for _ in range(n_downsamples):
                length = -(-length // 2)  # integer ceil(length / 2)
            return length
        return (reduced(img_rows), reduced(img_cols), 1)

    def wasserstein_loss(self, y_true, y_pred):
        """Mean of y_true * y_pred. Not referenced by any compile() call in this class."""
        # for wasserstein GAN loss
        return K.mean(y_true * y_pred)


    def perceptual_distance(self, y_true, y_pred):
        """
           Calculating perceptual distance
           Thanks to github.com/wandb/superres

           Both tensors are rescaled from [-1, 1] to [0, 255]; channel 0/1/2
           differences are then combined with weights that depend on the mean
           of channel 0, and the result is averaged over all elements.
        """
        y_true = (y_true+1.0)*127.5 # [-1,1] -> [0, 255]
        y_pred = (y_pred+1.0)*127.5 # [-1,1] -> [0, 255]
        rmean = (y_true[:, :, :, 0] + y_pred[:, :, :, 0]) / 2
        r = y_true[:, :, :, 0] - y_pred[:, :, :, 0]
        g = y_true[:, :, :, 1] - y_pred[:, :, :, 1]
        b = y_true[:, :, :, 2] - y_pred[:, :, :, 2]
        return K.mean(K.sqrt((((512+rmean)*r*r)/256) + 4*g*g + (((767-rmean)*b*b)/256)))


    def total_gen_loss(self, org_content, gen_content):
        """
           Generator reconstruction loss used as the second loss of `combined`.

           org_content: target images, gen_content: generator output.
           Returns 0.7 * (mean absolute error) + 0.3 * (squared difference of
           the VGG19 features, averaged over the last axis only, so the result
           keeps the remaining feature-map axes).
        """
        # custom perceptual loss function
        # NOTE(review): the images are passed to VGG19 as-is (the data pipeline
        # scales them to [-1, 1]); no VGG-specific preprocessing is applied here.
        vgg_org_content = self.vgg_content(org_content)
        vgg_gen_content = self.vgg_content(gen_content)
        content_loss = K.mean(K.square(vgg_org_content - vgg_gen_content), axis=-1)
        mae_gen_loss = K.mean(K.abs(org_content-gen_content))
        gen_total_err = 0.7*mae_gen_loss+0.3*content_loss # v1
        # updated loss function in v2
        #gen_total_err = 0.6*mae_gen_loss+0.3*content_loss+0.1*self.perceptual_distance(org_content, gen_content)
        return gen_total_err


    def FUNIE_generator1(self):
        """
           Inspired by the U-Net Generator with skip connections
           This is a much simpler architecture with fewer parameters (faster inference)

           Five stride-2 convolutions downsample the input; four upsampling
           stages (each concatenated with the matching encoder output) and one
           final 2x upsampling restore the resolution. Output uses tanh.
           Not used by __init__ (which builds FUNIE_generator2).
        """
        def conv2d(layer_input, filters, f_size=3, bn=True):
            ## for downsampling
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            #d = LeakyReLU(alpha=0.2)(d)
            d = Activation('relu')(d)
            if bn: d = BatchNormalization(momentum=0.8)(d)
            return d

        def deconv2d(layer_input, skip_input, filters, f_size=3, dropout_rate=0):
            ## for upsampling
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
            if dropout_rate: u = Dropout(dropout_rate)(u)
            u = BatchNormalization(momentum=0.8)(u)
            u = Concatenate()([u, skip_input])
            return u
        ## input
        d0 = Input(shape=self.img_shape); print(d0)
        ## downsample
        d1 = conv2d(d0, self.gf*1, f_size=5, bn=False)
        d2 = conv2d(d1, self.gf*4, f_size=4, bn=True)
        d3 = conv2d(d2, self.gf*8, f_size=4, bn=True)
        d4 = conv2d(d3, self.gf*8, f_size=3, bn=True)
        d5 = conv2d(d4, self.gf*8, f_size=3, bn=True)
        ## upsample
        u1 = deconv2d(d5, d4, self.gf*8)
        u2 = deconv2d(u1, d3, self.gf*8)
        u3 = deconv2d(u2, d2, self.gf*4)
        u4 = deconv2d(u3, d1, self.gf*1)
        u5 = UpSampling2D(size=2)(u4)
        output_img = Conv2D(self.channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u5)
        print(output_img); print();

        return Model(d0, output_img)


    def FUNIE_generator2(self):
        """
           Inspired by the U-Net Generator with skip connections.

           Unlike FUNIE_generator1 the convolutions here keep the resolution
           (default stride) and downsampling is done by three 2x2 max-pooling
           layers; three upsampling stages with skip connections restore it.
           The skip concatenations only line up when the image height and
           width are divisible by 8. Output uses tanh.
        """
        def conv2d(layer_input, filters, f_size=3, bn=True):
            ## resolution-preserving conv block (pooling does the downsampling)
            d = Conv2D(filters, kernel_size=f_size, padding='same')(layer_input)
            #d = LeakyReLU(alpha=0.2)(d)
            d = Activation('relu')(d)
            if bn: d = BatchNormalization(momentum=0.75)(d)
            return d

        def deconv2d(layer_input, skip_input, filters, f_size=3, dropout_rate=0):
            ## for upsampling
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
            if dropout_rate: u = Dropout(dropout_rate)(u)
            u = BatchNormalization(momentum=0.8)(u)
            u = Concatenate()([u, skip_input])
            return u
        ## input
        d0 = Input(shape=self.img_shape); print(d0)
        ## downsample
        d1 = conv2d(d0, self.gf*1, f_size=5, bn=False)
        d1a = MaxPooling2D(pool_size=(2, 2))(d1)
        d2 = conv2d(d1a, self.gf*2, f_size=4, bn=True)
        d3 = conv2d(d2, self.gf*2, f_size=4, bn=True)
        d3a = MaxPooling2D(pool_size=(2, 2))(d3)
        d4 = conv2d(d3a, self.gf*4, f_size=3, bn=True)
        d5 = conv2d(d4, self.gf*4, f_size=3, bn=True)
        d5a = MaxPooling2D(pool_size=(2, 2))(d5)
        d6 = conv2d(d5a, self.gf*8, f_size=3, bn=True)
        ## upsample
        u1 = deconv2d(d6, d5, self.gf*8)
        u2 = deconv2d(u1, d3, self.gf*8)
        u3 = deconv2d(u2, d1, self.gf*4)
        u4 = conv2d(u3, self.gf*4, f_size=3)
        u5 = conv2d(u4, self.gf*8, f_size=3)
        output_img = Conv2D(self.channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u5)
        return Model(d0, output_img)



    def FUNIE_discriminator(self):
        """
           Inspired by the pix2pix discriminator

           Takes two images, concatenates them along the channel axis, applies
           four stride-2 conv blocks and a final single-filter convolution with
           no activation. Returns Model([img_A, img_B], validity).
        """
        def d_layer(layer_input, filters, strides_=2,f_size=3, bn=True):
            ## Discriminator layers
            d = Conv2D(filters, kernel_size=f_size, strides=strides_, padding='same')(layer_input)
            #d = LeakyReLU(alpha=0.2)(d)
            d = Activation('relu')(d)
            if bn: d = BatchNormalization(momentum=0.8)(d)
            return d

        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)
        ## input
        combined_imgs = Concatenate(axis=-1)([img_A, img_B])
        ## Discriminator layers
        d1 = d_layer(combined_imgs, self.df, bn=False)
        d2 = d_layer(d1, self.df*2)
        d3 = d_layer(d2, self.df*4)
        d4 = d_layer(d3, self.df*8)
        validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)
        # return model
        return Model([img_A, img_B], validity)

In [None]:
"""
# > Various modules for handling data 
# > Maintainer: https://github.com/xahidbuffon
"""
from __future__ import division
from __future__ import absolute_import
import os
import random
import cv2
import fnmatch
import numpy as np
from scipy import misc

def deprocess(x):
    """Map values from the [-1, 1] range back to [0, 255]."""
    shifted = x + 1.0
    return 127.5 * shifted

def preprocess(x):
    """Map values from the [0, 255] range to [-1, 1]."""
    scaled = x / 127.5
    return scaled - 1.0

def augment(a_img, b_img):
    """
       Randomly augment an image pair (a is the distorted one).

       1. a_img is replaced by a random convex blend of a_img and b_img.
       2. With probability 0.25 both images are flipped left-right.
       3. With probability 0.25 both images are flipped up-down.
       Returns the (a_img, b_img) pair.
    """
    # blend factor in [0, 1): 0 keeps a_img, values near 1 approach b_img
    blend = random.random()
    a_img = (1 - blend)*a_img + blend*b_img
    # independent coin toss per flip direction, always applied to both images
    for flip in (np.fliplr, np.flipud):
        if random.random() < 0.25:
            a_img, b_img = flip(a_img), flip(b_img)
    return a_img, b_img

def getPaths(data_dir):
    """
       Recursively collect image file paths under `data_dir`.

       A file is kept when its name ends with .png, .jpg or .jpeg in any
       letter case (the previous pattern list missed lowercase '.jpeg').
       The tree is walked once and the result is sorted, so the order is
       reproducible across runs and file systems; DataLoader pairs the
       trainA / trainB lists purely by index and relies on that.

       Returns a numpy array of path strings (empty if nothing matches or
       `data_dir` does not exist).
    """
    exts = ('.png', '.jpg', '.jpeg')
    image_paths = []
    for d, _, fList in os.walk(data_dir):
        for filename in fList:
            if filename.lower().endswith(exts):
                image_paths.append(os.path.join(d, filename))
    image_paths.sort()
    return np.asarray(image_paths)

def read_and_resize(path, img_res):
    """
       Load one image as a float RGB array resized to `img_res`.

       path: image file path.
       img_res: size tuple forwarded unchanged to cv2.resize.
       Raises IOError if the file cannot be read (cv2.imread returns None).
    """
    img = cv2.imread(path)
    if img is None:
        raise IOError("could not read image: {0}".format(path))
    # np.float was only an alias of the builtin float and was removed from
    # NumPy; np.float64 is the same dtype
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float64)
    img = cv2.resize(img, img_res)
    return img

def read_and_resize_pair(pathA, pathB, img_res):
    """
       Load two images (A first, then B) as float RGB arrays resized to `img_res`.

       img_res: size tuple forwarded unchanged to cv2.resize.
       Returns (img_A, img_B).
       Raises IOError if either file cannot be read (cv2.imread returns None).
    """
    loaded = []
    for path in (pathA, pathB):
        img = cv2.imread(path)
        if img is None:
            raise IOError("could not read image: {0}".format(path))
        # np.float was removed from NumPy; np.float64 is the same dtype
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float64)
        loaded.append(cv2.resize(img, img_res))
    img_A, img_B = loaded
    return img_A, img_B

def get_local_test_data(data_dir, img_res=(256, 256)):
    """
       Load every image found under `data_dir` (via getPaths), resize each to
       `img_res` and return them stacked into one array passed through
       preprocess(). Asserts that `data_dir` exists.
    """
    assert os.path.exists(data_dir), "local image path doesnt exist"
    stacked = np.array([read_and_resize(p, img_res) for p in getPaths(data_dir)])
    return preprocess(stacked)

class DataLoader():
    """
       Loads paired (distorted / enhanced) images from a dataset folder.

       Expected sub-folders of `data_dir`:
         trainA     -- distorted images
         trainB     -- enhanced images (paired with trainA by list index)
         validation -- only counted (num_val); see load_val_data
         test       -- read only when test_only=True
    """
    def __init__(self, data_dir, dataset_name, img_res=(256, 256), test_only=False):
        """
           data_dir: dataset root folder.
           dataset_name: stored as self.DATA (not otherwise used here).
           img_res: size tuple forwarded to the image readers.
           test_only: when True only the test paths are collected.
        """
        self.img_res = img_res
        self.DATA = dataset_name
        self.data_dir = data_dir
        if not test_only:
            self.trainA_paths = getPaths(os.path.join(self.data_dir, "trainA")) # distorted
            self.trainB_paths = getPaths(os.path.join(self.data_dir, "trainB")) # enhanced
            # truncate the longer list so both have the same number of entries
            n_pairs = min(len(self.trainA_paths), len(self.trainB_paths))
            self.trainA_paths = self.trainA_paths[:n_pairs]
            self.trainB_paths = self.trainB_paths[:n_pairs]
            self.val_paths = getPaths(os.path.join(self.data_dir, "validation"))
            self.num_train, self.num_val = len(self.trainA_paths), len(self.val_paths)
            print ("{0} training pairs\n".format(self.num_train))
        else:
            self.test_paths    = getPaths(os.path.join(self.data_dir, "test"))
            print ("{0} test images\n".format(len(self.test_paths)))

    def get_test_data(self, batch_size=1):
        """Return `batch_size` distinct, randomly chosen test images, preprocessed."""
        idx = np.random.choice(np.arange(len(self.test_paths)), batch_size, replace=False)
        imgs = [read_and_resize(p, self.img_res) for p in self.test_paths[idx]]
        return preprocess(np.array(imgs))

    def load_val_data(self, batch_size=1):
        """
           Return `batch_size` distinct, randomly chosen (distorted, enhanced)
           pairs, preprocessed.

           NOTE: the pairs are read from the trainA / trainB lists; the
           `validation` folder holds single images and is not read here.
           Indices are therefore drawn from the number of training pairs.
           Drawing them from num_val (as before) raised IndexError whenever
           the validation folder held more files than there are training
           pairs, and failed outright when it was empty.
        """
        idx = np.random.choice(np.arange(len(self.trainA_paths)), batch_size, replace=False)
        pathsA = self.trainA_paths[idx]
        pathsB = self.trainB_paths[idx]
        imgs_A, imgs_B = [], []
        for path_A, path_B in zip(pathsA, pathsB):
            img_A, img_B = read_and_resize_pair(path_A, path_B, self.img_res)
            imgs_A.append(img_A)
            imgs_B.append(img_B)
        imgs_A = preprocess(np.array(imgs_A))
        imgs_B = preprocess(np.array(imgs_B))
        return imgs_A, imgs_B

    def load_batch(self, batch_size=1, data_augment=True):
        """
           Generator over one pass of the training pairs.

           Yields (imgs_A, imgs_B) arrays of `batch_size` preprocessed images,
           in list order (no shuffling). Sets self.n_batches to
           num_train // batch_size and yields exactly that many batches; a
           trailing partial batch is skipped. (The loop previously stopped one
           batch early and never yielded the last full batch.)
        """
        self.n_batches = self.num_train//batch_size
        for i in range(self.n_batches):
            batch_A = self.trainA_paths[i*batch_size:(i+1)*batch_size]
            batch_B = self.trainB_paths[i*batch_size:(i+1)*batch_size]
            imgs_A, imgs_B = [], []
            for path_A, path_B in zip(batch_A, batch_B):
                img_A, img_B = read_and_resize_pair(path_A, path_B, self.img_res)
                if (data_augment):
                    img_A, img_B = augment(img_A, img_B)
                imgs_A.append(img_A)
                imgs_B.append(img_B)
            imgs_A = preprocess(np.array(imgs_A))
            imgs_B = preprocess(np.array(imgs_B))
            yield imgs_A, imgs_B




In [None]:
# dataset location and loader
data_dir = "../input/evprdata/Paired/"
dataset_name = "underwater_imagenet" # options: {'underwater_imagenet', 'underwater_dark'}
data_loader = DataLoader(os.path.join(data_dir, dataset_name), dataset_name)
# output folders for sampled validation images and model checkpoints
samples_dir = os.path.join("./samples/funieGAN/", dataset_name)
checkpoint_dir = os.path.join("./checkpoints/funieGAN/", dataset_name)
for out_dir in (samples_dir, checkpoint_dir):
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)


In [None]:
# number of passes used to size the training run (see num_step)
num_epoch = 7
# images per training batch
batch_size = 4
# run validation sampling every `val_interval` training steps
val_interval = 200
# number of validation pairs drawn for each sample figure
N_val_samples = 3
# checkpoint once per num_train // batch_size steps
save_model_interval = data_loader.num_train//batch_size
# total number of training steps
num_step = num_epoch*save_model_interval

In [None]:
# display: (steps between checkpoints, number of training pairs, total steps)
save_model_interval, data_loader.num_train, num_step

In [None]:
## build generator, discriminator and combined model
funie_gan = FUNIE_GAN()
## adversarial targets: one discriminator patch map per image in the batch
patch_label_shape = (batch_size,) + funie_gan.disc_patch
valid = np.ones(patch_label_shape)
fake = np.zeros(patch_label_shape)

In [None]:
from __future__ import division
from __future__ import absolute_import
import os
import random
import cv2
import fnmatch
import numpy as np
from scipy import misc

In [None]:
import matplotlib.pyplot as plt

def save_val_samples_funieGAN(samples_dir, gen_imgs, step, N_samples=3, N_ims=3):
    """
       Save a grid of validation samples as '<samples_dir>/<step>.png'.

       gen_imgs is consumed column by column: the first N_samples images fill
       column 0 ('Input'), the next N_samples column 1 ('Generated'), and so
       on for N_ims columns. Titles exist for three columns only.
    """
    row=N_samples; col=N_ims;
    titles = ['Input', 'Generated', 'Original']
    # squeeze=False keeps axs 2-D even for a single row or column
    # (N_samples=1 used to fail on axs[i, j])
    fig, axs = plt.subplots(row, col, squeeze=False)
    cnt = 0
    for j in range(col):
        for i in range(row):
            axs[i,j].imshow(gen_imgs[cnt])
            axs[i, j].set_title(titles[j])
            axs[i,j].axis('off')
            cnt += 1
    fig.savefig(os.path.join(samples_dir, ("%d.png" %step)))
    # close this figure explicitly rather than whichever one is current
    plt.close(fig)


def save_val_samples_unpaired(samples_dir, gen_imgs, step, N_samples=1, N_ims=6):
    """
       Save a grid of samples as '<samples_dir>/_<step>.png'.

       The grid has 2*N_samples rows and N_ims//2 columns and is filled row
       by row from gen_imgs; column j is titled titles[j] (three titles are
       defined).
    """
    row=2*N_samples; col=N_ims//2;
    titles = ['Original','Translated','Reconstructed']
    # squeeze=False keeps axs 2-D even when there is a single column
    fig, axs = plt.subplots(row, col, squeeze=False)
    cnt = 0
    for i in range(row):
        for j in range(col):
            axs[i,j].imshow(gen_imgs[cnt])
            axs[i, j].set_title(titles[j])
            axs[i,j].axis('off')
            cnt += 1
    fig.savefig(os.path.join(samples_dir, ("_%d.png" %step)))
    # close this figure explicitly rather than whichever one is current
    plt.close(fig)


def save_test_samples_funieGAN(samples_dir, gen_imgs, step=0):
    """
       Save gen_imgs[0] ('Input') and gen_imgs[1] ('Generated') side by side
       as '<samples_dir>/_test_<step>.png'.
    """
    fig, axs = plt.subplots(1, 2)
    for panel, caption in enumerate(("Input", "Generated")):
        ax = axs[panel]
        ax.imshow(gen_imgs[panel])
        ax.set_title(caption)
        ax.axis('off')
    out_path = os.path.join(samples_dir,("_test_%d.png" %step))
    fig.savefig(out_path)
    plt.close()


def viz_gen_and_dis_losses(all_D_losses, all_G_losses, save_dir=None):
    """
       Plot discriminator (red) and generator (green) losses per step.

       save_dir: when given, the figure is also written to
                 '<save_dir>/_conv.png'. The check used to be inverted, so
                 nothing was saved when a directory was passed and the default
                 None raised a TypeError inside os.path.join.
    """
    plt.plot(all_D_losses, 'r')
    plt.plot(all_G_losses, 'g')
    plt.title('Model convergence'); plt.ylabel('Losses'); plt.xlabel('# of steps');
    plt.legend(['Discriminator network', 'Generator network'], loc='upper right')
    # save before show(): depending on the backend the figure may be cleared
    # once it has been shown
    if save_dir:
        plt.savefig(os.path.join(save_dir, '_conv.png'))
    plt.show();



In [None]:
## training loop
## Alternates one discriminator update (real pair, then generated pair) with
## one generator update through the combined model, for num_step steps.
step = 0
all_D_losses = []; all_G_losses = []
## `<` (not `<=`): after the inner break at step == num_step the outer loop
## must not start another pass and train one extra batch
while (step < num_step):
    step_at_pass_start = step
    for imgs_distorted, imgs_good in data_loader.load_batch(batch_size):
        ##  train the discriminator
        imgs_fake = funie_gan.generator.predict(imgs_distorted)
        d_loss_real = funie_gan.discriminator.train_on_batch([imgs_good, imgs_distorted], valid)
        d_loss_fake = funie_gan.discriminator.train_on_batch([imgs_fake, imgs_distorted], fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        ## train the generator
        g_loss = funie_gan.combined.train_on_batch([imgs_good, imgs_distorted], [valid, imgs_good])
        ## increment step, save losses, and print them
        step += 1; all_D_losses.append(d_loss[0]);  all_G_losses.append(g_loss[0]);
        if step%50==0:
            print ("Step {0}/{1}: lossD: {2}, lossG: {3}".format(step, num_step, d_loss[0], g_loss[0]))
        ## validate and save generated samples at regular intervals
        if (step % val_interval==0):
            imgs_distorted, imgs_good = data_loader.load_val_data(batch_size=N_val_samples)
            imgs_fake = funie_gan.generator.predict(imgs_distorted)
            ## layout: [inputs | generated | originals], N_val_samples each
            gen_imgs = np.concatenate([imgs_distorted, imgs_fake, imgs_good])
            gen_imgs = 0.5 * gen_imgs + 0.5 # Rescale to 0-1
            save_val_samples_funieGAN(samples_dir, gen_imgs, step, N_samples=N_val_samples)
            ## also show the same grid inline; rows follow N_val_samples
            ## instead of a hard-coded 3
            titles = ['Input', 'Generated', 'Original']
            fig, axs = plt.subplots(N_val_samples, 3, squeeze=False)
            cnt = 0
            for j in range(3):
                for i in range(N_val_samples):
                    axs[i,j].imshow(gen_imgs[cnt])
                    axs[i, j].set_title(titles[j])
                    axs[i,j].axis('off')
                    cnt += 1
            plt.show()
        ## save model and weights
        if (step % save_model_interval==0):
            model_name = os.path.join(checkpoint_dir, ("model_%d" %step))
            with open(model_name+"_.json", "w") as json_file:
                json_file.write(funie_gan.generator.to_json())
            funie_gan.generator.save_weights(model_name+"_.h5")
            print("\nSaved trained model in {0}\n".format(checkpoint_dir))
        ## sanity
        if (step>=num_step): break
    ## a pass that produced no batch would otherwise loop forever
    if step == step_at_pass_start:
        raise RuntimeError("data_loader.load_batch yielded no batches; check the dataset size against batch_size")

In [None]:
# show only the shape; displaying the bare array dumps every pixel value into the cell output
gen_imgs.shape

In [None]:
# gen_imgs is laid out as [inputs | generated | originals] with N_val_samples
# images per group, so the generated counterpart of input k sits at
# index N_val_samples + k (both panels used to show the same input image)
sample_idx = 2
fig, axs = plt.subplots(1, 2)
axs[0].imshow(gen_imgs[sample_idx])
axs[0].set_title("Input")
axs[0].axis('off')
axs[1].imshow(gen_imgs[N_val_samples + sample_idx])
axs[1].set_title("Generated")
axs[1].axis('off')

In [None]:
# 3x3 preview of gen_imgs, filled column by column
row, col = 3, 3
fig, axs = plt.subplots(row, col)
cnt = 0
for j in range(col):
    for i in range(row):
        panel = axs[i, j]
        panel.imshow(gen_imgs[cnt])
        panel.axis('off')
        cnt += 1