## Importing necessary libraries

In [32]:
from __future__ import print_function, division
import cv2
from keras.layers import LayerNormalization, Input, Dense, Reshape, Flatten, Dropout, Concatenate, BatchNormalization, Activation, ZeroPadding2D, LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
from matplotlib.pyplot import imread
import sys
import numpy as np
import os
from glob import glob

## Downloading training and testing data

In [33]:
# Run download_data.py in a shell, passing the dataset name as its only argument.
# NOTE(review): presumably this populates ./datasets/summer2winter_yosemite/,
# which the loaders below glob from -- confirm against the script.
!python download_data.py summer2winter_yosemite

## Defining functions

In [53]:
def im_read(path):
    """Read the image stored at ``path`` and return its pixels as floats.

    The file is decoded with ``matplotlib.pyplot.imread`` and the resulting
    array is converted to Python ``float`` dtype.
    """
    pixels = imread(path)
    return pixels.astype(float)

def load_data(data_dir, domain, batch_size=1,  img_shape = (128,128), is_testing=False):
    """Load a random batch of images from a single domain.

    Args:
        data_dir: Dataset name. Files are globbed from
            ``./datasets/<data_dir>/<data_dir>/<train|test><domain>/*``.
        domain: Suffix appended to the split name (e.g. ``"A"`` or ``"B"``).
        batch_size: Number of files drawn per call. ``np.random.choice`` is
            used with its default settings, i.e. sampling with replacement.
        img_shape: Target size handed to ``cv2.resize`` as ``dsize``
            (``cv2`` interprets it as ``(width, height)``).
        is_testing: If true, read from the ``test`` split and skip the random
            horizontal flip.

    Returns:
        The resized images stacked on a new first axis and mapped through
        ``x / 127.5 - 1`` (i.e. [-1, 1] assuming 0-255 pixel values).

    Raises:
        ValueError: If no file matches the glob pattern.
    """
    split = "test" if is_testing else "train"
    pattern = './datasets/%s/%s/%s%s/*' % (data_dir, data_dir, split, domain)
    files = glob(pattern)
    if not files:
        # Fail with the offending pattern instead of the opaque error that
        # np.random.choice raises when asked to sample from an empty list.
        raise ValueError("No images found matching %r" % pattern)

    chosen = np.random.choice(files, size=batch_size)

    imgs = []
    for img_path in chosen:
        # Resizing is needed for both splits; only augmentation differs.
        img = cv2.resize(im_read(img_path), img_shape)
        # Random horizontal flip is a training-only augmentation. The RNG is
        # not consumed at test time thanks to short-circuit evaluation.
        if not is_testing and np.random.random() > 0.5:
            img = np.fliplr(img)
        imgs.append(img)

    return np.array(imgs)/127.5 - 1.


def load_img(path,img_shape = (128,128)):
    """Load one image as a batch of size one.

    The image is read with ``im_read``, resized with ``cv2.resize`` to
    ``img_shape`` and mapped through ``x / 127.5 - 1``. A new leading axis is
    then added; the indexing expression requires the resized image to be
    three-dimensional.
    """
    resized = cv2.resize(im_read(path), img_shape)
    scaled = resized/127.5 - 1.
    return scaled[np.newaxis, :, :, :]

def load_batch(data_dir, batch_size=1, is_testing=False,img_shape = (128,128)):
    """Yield batches of images from domains A and B for one pass over the data.

    Args:
        data_dir: Dataset name. Files are globbed from
            ``./datasets/<data_dir>/<data_dir>/<train|test><A|B>/*``.
        batch_size: Number of images per domain in each yielded batch.
        is_testing: If true, read from the ``test`` split (the same split name
            ``load_data`` uses) and skip the random horizontal flip.
        img_shape: Target size handed to ``cv2.resize`` as ``dsize``.

    Yields:
        ``(imgs_A, imgs_B)``: two arrays with ``batch_size`` images each,
        mapped through ``x / 127.5 - 1``. The number of batches is
        ``min(len(A), len(B)) // batch_size``; leftover files that do not fill
        a whole batch are not used.
    """
    # Was "val", which does not match the "test" split that load_data reads.
    data_type = "test" if is_testing else "train"
    pattern = './datasets/%s/%s/%s%s/*'
    path_A = glob(pattern % (data_dir, data_dir, data_type, "A"))
    path_B = glob(pattern % (data_dir, data_dir, data_type, "B"))

    n_batches = int(min(len(path_A), len(path_B)) / batch_size)
    total_samples = n_batches * batch_size

    # Sample n_batches * batch_size from each path list (without replacement)
    # so that the model sees every selected sample from both domains once.
    path_A = np.random.choice(path_A, total_samples, replace=False)
    path_B = np.random.choice(path_B, total_samples, replace=False)

    # Iterate over *all* n_batches; the previous range(n_batches-1) silently
    # dropped the final batch (and yielded nothing when n_batches == 1).
    for i in range(n_batches):
        start, stop = i*batch_size, (i+1)*batch_size
        imgs_A, imgs_B = [], []
        for file_A, file_B in zip(path_A[start:stop], path_B[start:stop]):
            img_A = cv2.resize(im_read(file_A), img_shape)
            img_B = cv2.resize(im_read(file_B), img_shape)

            # Training-only augmentation: flip both images of the pair together.
            if not is_testing and np.random.random() > 0.5:
                img_A = np.fliplr(img_A)
                img_B = np.fliplr(img_B)

            imgs_A.append(img_A)
            imgs_B.append(img_B)

        imgs_A = np.array(imgs_A)/127.5 - 1.
        imgs_B = np.array(imgs_B)/127.5 - 1.

        yield imgs_A, imgs_B

In [54]:
def build_discriminator(img_shape = (128,128,3)):
    """Build the convolutional discriminator.

    The network stacks four stride-2 convolution blocks (64, 128, 256 and 512
    filters) followed by a stride-1 convolution with a single output channel,
    so the output is a map of per-region validity scores rather than a single
    scalar.

    Args:
        img_shape: Shape of the input image tensor, without the batch axis.

    Returns:
        A Keras ``Model`` mapping an image to its validity map.
    """
    base_filters = 64

    def down_block(x, n_filters, kernel=4, use_norm=True):
        """Strided convolution -> LeakyReLU -> optional LayerNormalization."""
        x = Conv2D(n_filters, kernel_size=kernel, strides=2, padding='same')(x)
        x = LeakyReLU(alpha=0.2)(x)
        return LayerNormalization()(x) if use_norm else x

    image = Input(shape=img_shape)

    # The first block is left un-normalised; the rest double the filter count.
    features = down_block(image, base_filters, use_norm=False)
    for multiplier in (2, 4, 8):
        features = down_block(features, base_filters * multiplier)

    validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(features)

    return Model(image, validity)

In [55]:
def build_generator(img_shape=(128,128,3)):
    """Build the U-Net generator.

    Four stride-2 convolution blocks downsample the input; three
    upsample-and-convolve blocks bring it back up, each concatenated with the
    encoder feature map of matching depth (skip connections). A final
    upsampling plus a ``tanh`` convolution produces the output image.

    Args:
        img_shape: Shape of the input image tensor, without the batch axis.
            The last entry is used as the number of output channels. Defaults
            to ``(128, 128, 3)``, the shape that was previously hard-coded.
            NOTE(review): height and width presumably need to be multiples of
            16 for the skip concatenations to line up -- confirm before using
            other sizes.

    Returns:
        A Keras ``Model`` mapping an image to a translated image with the same
        number of channels and values in the ``tanh`` range.
    """
    gf = 32                     # filters in the first encoder block
    channels = img_shape[-1]    # output channels follow the input image

    def conv2d(layer_input, filters, f_size=4):
        """Layers used during downsampling"""
        d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
        d = LeakyReLU(alpha=0.2)(d)
        d = LayerNormalization()(d)
        return d

    def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
        """Layers used during upsampling"""
        u = UpSampling2D(size=2)(layer_input)
        u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
        if dropout_rate:
            u = Dropout(dropout_rate)(u)
        u = LayerNormalization()(u)
        # Skip connection: join with the encoder output of the same resolution.
        u = Concatenate()([u, skip_input])
        return u

    # Image input
    d0 = Input(shape=img_shape)

    # Downsampling
    d1 = conv2d(d0, gf)
    d2 = conv2d(d1, gf*2)
    d3 = conv2d(d2, gf*4)
    d4 = conv2d(d3, gf*8)

    # Upsampling
    u1 = deconv2d(d4, d3, gf*4)
    u2 = deconv2d(u1, d2, gf*2)
    u3 = deconv2d(u2, d1, gf)

    u4 = UpSampling2D(size=2)(u3)
    output_img = Conv2D(channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u4)

    return Model(d0, output_img)

In [56]:
class CycleGAN():
    """CycleGAN for translating images between two domains, A and B.

    Attributes:
        g_AB, g_BA: Generators translating A -> B and B -> A.
        d_A, d_B: Discriminators for domains A and B (MSE loss).
        combined: Model that trains both generators, with the discriminators
            marked non-trainable, on adversarial (MSE), cycle-consistency
            (MAE) and identity (MAE) losses.
        disc_patch: Shape of the discriminator output for one image; used to
            build the adversarial ground-truth arrays in ``train``.
    """

    def __init__(self):
        """Build and compile the discriminators, generators and combined model."""
        # Input shape
        # NOTE(review): build_discriminator()/build_generator() are called with
        # their defaults below, so the networks only match self.img_shape while
        # it stays (128, 128, 3).
        self.img_rows = 128
        self.img_cols = 128
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # Configure data loader
        self.data_dir = 'summer2winter_yosemite'
        # Calculate output shape of D (PatchGAN): the discriminator halves the
        # resolution four times (four stride-2 convolutions).
        patch = int(self.img_rows / 2**4)
        self.disc_patch = (patch, patch, 1)

        # Loss weights
        self.lambda_cycle = 10.0                    # Cycle-consistency loss
        self.lambda_id = 0.1 * self.lambda_cycle    # Identity loss

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminators
        self.d_A = build_discriminator()
        self.d_B = build_discriminator()
        self.d_A.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])
        self.d_B.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])

        #-------------------------
        # Construct Computational
        #   Graph of Generators
        #-------------------------

        # Build the generators
        self.g_AB = build_generator()
        self.g_BA = build_generator()

        # Input images from both domains
        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

        # Translate images to the other domain
        fake_B = self.g_AB(img_A)
        fake_A = self.g_BA(img_B)
        # Translate images back to original domain
        reconstr_A = self.g_BA(fake_B)
        reconstr_B = self.g_AB(fake_A)
        # Identity mapping of images
        img_A_id = self.g_BA(img_A)
        img_B_id = self.g_AB(img_B)

        # For the combined model we will only train the generators
        self.d_A.trainable = False
        self.d_B.trainable = False

        # Discriminators determines validity of translated images
        valid_A = self.d_A(fake_A)
        valid_B = self.d_B(fake_B)

        # Combined model trains generators to fool discriminators.
        # Output order matters: train() indexes the returned losses by it.
        self.combined = Model(inputs=[img_A, img_B],
                              outputs=[ valid_A, valid_B,
                                        reconstr_A, reconstr_B,
                                        img_A_id, img_B_id ])
        self.combined.compile(loss=['mse', 'mse',
                                    'mae', 'mae',
                                    'mae', 'mae'],
                            loss_weights=[  1, 1,
                                            self.lambda_cycle, self.lambda_cycle,
                                            self.lambda_id, self.lambda_id ],
                            optimizer=optimizer)

    def train(self, epochs, batch_size=1, sample_interval=50):
        """Alternate discriminator and generator updates over the training set.

        Args:
            epochs: Number of passes over the batches produced by ``load_batch``.
            batch_size: Images per domain in each batch.
            sample_interval: A sample grid is saved via ``sample_images``
                whenever ``batch_i % sample_interval == 0``.

        A progress line with the discriminator and generator losses is printed
        after every batch.
        """
        start_time = datetime.datetime.now()

        # Adversarial loss ground truths
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)
        for epoch in range(epochs):
            for batch_i, (imgs_A, imgs_B) in enumerate(load_batch(self.data_dir,batch_size)):
                # ----------------------
                #  Train Discriminators
                # ----------------------

                # Translate images to opposite domain
                fake_B = self.g_AB.predict(imgs_A)
                fake_A = self.g_BA.predict(imgs_B)

                # Train the discriminators (original images = real / translated = Fake)
                dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
                dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
                dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

                dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
                dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
                dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

                # Total discriminator loss: [loss, accuracy] averaged over d_A and d_B
                d_loss = 0.5 * np.add(dA_loss, dB_loss)

                # ------------------
                #  Train Generators
                # ------------------

                # Train the generators.
                # g_loss = [total, adv_A, adv_B, recon_A, recon_B, id_A, id_B]
                g_loss = self.combined.train_on_batch([imgs_A, imgs_B],
                                                        [valid, valid,
                                                        imgs_A, imgs_B,
                                                        imgs_A, imgs_B])

                elapsed_time = datetime.datetime.now() - start_time

                # Plot the progress
                print ("[Epoch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s "
                       % (epoch, epochs,
                          d_loss[0], 100*d_loss[1],
                          g_loss[0],
                          np.mean(g_loss[1:3]),
                          np.mean(g_loss[3:5]),
                          # Both identity terms (indices 5 and 6); the previous
                          # slice [5:6] reported only the first one.
                          np.mean(g_loss[5:7]),
                          elapsed_time))

                # If at save interval => save generated image samples
                if batch_i % sample_interval == 0:
                    self.sample_images(epoch, batch_i)

    def sample_images(self, epoch, batch_i):
        """Save a 2x3 grid of original / translated / reconstructed test images.

        One random test image per domain is loaded with ``load_data``,
        translated to the other domain and back. The figure is written to
        ``images/<data_dir>/<epoch>_<batch_i>.png``; ``epoch`` and ``batch_i``
        are only used to name the file.
        """
        os.makedirs('images/%s' % self.data_dir, exist_ok=True)
        r, c = 2, 3

        imgs_A = load_data(self.data_dir, domain="A", batch_size=1, is_testing=True)
        imgs_B = load_data(self.data_dir, domain="B", batch_size=1, is_testing=True)

        # Translate images to the other domain
        fake_B = self.g_AB.predict(imgs_A)
        fake_A = self.g_BA.predict(imgs_B)
        # Translate back to original domain
        reconstr_A = self.g_BA.predict(fake_B)
        reconstr_B = self.g_AB.predict(fake_A)

        # Row 0: A -> B -> A, row 1: B -> A -> B
        gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A, imgs_B, fake_A, reconstr_B])

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        titles = ['Original', 'Translated', 'Reconstructed']
        fig, axs = plt.subplots(r, c)
        # axs.flat walks the grid row by row, matching the order of gen_imgs.
        for idx, ax in enumerate(axs.flat):
            ax.imshow(gen_imgs[idx])
            ax.set_title(titles[idx % c])
            ax.axis('off')
        fig.savefig("images/%s/%d_%d.png" % (self.data_dir, epoch, batch_i))
        # Close this specific figure so repeated sampling does not accumulate
        # open figures.
        plt.close(fig)

In [58]:
# Instantiate the model: the constructor builds and compiles both
# discriminators, both generators and the combined generator-training model.
gan = CycleGAN()

In [61]:
# Train for a single epoch with 64 images per domain in each batch; a sample
# grid is saved whenever the batch index is a multiple of 200.
gan.train(epochs=1, batch_size=64, sample_interval=200)

[Epoch 0/1] [D loss: 0.495799, acc:  53%] [G loss: 13.706840, adv: 0.901396, recon: 0.532270, id: 0.638069] time: 0:00:02.845821 
[Epoch 0/1] [D loss: 0.432235, acc:  56%] [G loss: 13.969219, adv: 0.786765, recon: 0.557389, id: 0.594420] time: 0:00:06.712230 
[Epoch 0/1] [D loss: 0.409222, acc:  52%] [G loss: 12.994154, adv: 0.807697, recon: 0.511129, id: 0.530496] time: 0:00:09.430221 
[Epoch 0/1] [D loss: 0.409441, acc:  54%] [G loss: 12.748266, adv: 0.833375, recon: 0.496452, id: 0.541050] time: 0:00:12.208774 
[Epoch 0/1] [D loss: 0.377452, acc:  55%] [G loss: 12.601738, adv: 0.819233, recon: 0.488743, id: 0.591999] time: 0:00:14.954661 
[Epoch 0/1] [D loss: 0.322129, acc:  65%] [G loss: 12.272699, adv: 0.967345, recon: 0.457180, id: 0.580169] time: 0:00:18.736880 
[Epoch 0/1] [D loss: 0.329408, acc:  62%] [G loss: 11.399852, adv: 0.897283, recon: 0.424652, id: 0.514697] time: 0:00:21.484901 
[Epoch 0/1] [D loss: 0.319090, acc:  62%] [G loss: 11.260750, adv: 1.036484, recon: 0.4048

In [62]:
# Save one more sample grid from the trained generators. The two arguments are
# only used to name the output file (images/summer2winter_yosemite/100_64.png).
gan.sample_images(100,64)

