In [None]:
from __future__ import print_function, division
import scipy
from keras.callbacks import ModelCheckpoint
from keras.datasets import mnist
#from keras_contrib.layers.normalization import Normalization
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
import sys
from data_loader import DataLoader
import numpy as np
import os

class DiscoGAN():
    """Two-domain (A <-> B) image translation GAN in the style of DiscoGAN.

    Builds two U-Net generators (``g_AB``: A -> B, ``g_BA``: B -> A) and two
    PatchGAN-style discriminators (``d_A``, ``d_B``). The ``combined`` model
    trains only the generators against three objectives per direction:
    adversarial (MSE against the discriminator output), translation (MAE
    between the translated image and the paired real image of the other
    domain) and cycle-consistency (MAE between the reconstruction and the
    original image).
    """

    def __init__(self, weights_path='saved_model/model15.h5'):
        """Configure data loading and build/compile all sub-models.

        Args:
            weights_path: HDF5 weights file loaded into the combined model
                (which contains both generators and both discriminators as
                sub-models) before it is compiled, to resume training.
                Defaults to the checkpoint ``train`` writes after epoch 15.
                Pass ``None`` (or ``''``) to start from fresh weights.
        """
        # Input shape
        self.img_rows = 256
        self.img_cols = 256
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # Configure data loader
        self.dataset_name = 'saree'
        self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                      img_res=(self.img_rows, self.img_cols))

        # Output shape of D (PatchGAN): build_discriminator applies four
        # stride-2 convolutions, so each spatial side shrinks by 2**4.
        patch = int(self.img_rows / 2**4)
        self.disc_patch = (patch, patch, 1)

        # Number of filters in the first layer of G and D
        self.gf = 64
        self.df = 64

        # Adam(lr=0.0002, beta_1=0.5); this single instance is shared by all
        # three compiled models below.
        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminators
        self.d_A = self.build_discriminator()
        self.d_B = self.build_discriminator()
        self.d_A.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])
        self.d_B.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])

        #-------------------------
        # Construct Computational
        #   Graph of Generators
        #-------------------------

        # Build the generators
        self.g_AB = self.build_generator()
        self.g_BA = self.build_generator()

        # Input images from both domains
        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

        # Translate images to the other domain
        fake_B = self.g_AB(img_A)
        fake_A = self.g_BA(img_B)
        # Translate images back to original domain
        reconstr_A = self.g_BA(fake_B)
        reconstr_B = self.g_AB(fake_A)

        # For the combined model we will only train the generators. The
        # discriminators were compiled above while still trainable, so their
        # own train_on_batch calls continue to update them.
        self.d_A.trainable = False
        self.d_B.trainable = False

        # Discriminators determine validity of translated images
        valid_A = self.d_A(fake_A)
        valid_B = self.d_B(fake_B)

        # Objectives
        # + Adversarial: Fool domain discriminators
        # + Translation: Minimize MAE between e.g. fake B and true B
        # + Cycle-consistency: Minimize MAE between reconstructed images and original
        self.combined = Model(inputs=[img_A, img_B],
                              outputs=[valid_A, valid_B,
                                       fake_B, fake_A,
                                       reconstr_A, reconstr_B])
        if weights_path:
            # Resume from a previous checkpoint written by ``train``.
            self.combined.load_weights(weights_path)
        self.combined.compile(loss=['mse', 'mse',
                                    'mae', 'mae',
                                    'mae', 'mae'],
                              optimizer=optimizer)

    def build_generator(self):
        """Build a U-Net generator mapping ``img_shape`` -> ``img_shape``.

        Seven stride-2 downsampling convolutions are mirrored by upsampling
        blocks with skip connections. A final upsample + tanh convolution
        produces an image with ``self.channels`` channels.
        """

        def conv2d(layer_input, filters, f_size=4, normalize=True):
            """Downsampling block: stride-2 conv -> LeakyReLU -> optional BatchNorm."""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.4)(d)
            if normalize:
                d = BatchNormalization()(d)
            return d

        def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
            """Upsampling block: x2 upsample -> ReLU conv -> optional Dropout
            -> BatchNorm, then concatenate the matching encoder feature map."""
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
            if dropout_rate:
                u = Dropout(dropout_rate)(u)
            u = BatchNormalization()(u)
            u = Concatenate()([u, skip_input])
            return u

        # Image input
        d0 = Input(shape=self.img_shape)

        # Downsampling (each step halves the spatial size)
        d1 = conv2d(d0, self.gf, normalize=False)
        d2 = conv2d(d1, self.gf*2)
        d3 = conv2d(d2, self.gf*4)
        d4 = conv2d(d3, self.gf*8)
        d5 = conv2d(d4, self.gf*8)
        d6 = conv2d(d5, self.gf*8)
        d7 = conv2d(d6, self.gf*8)

        # Upsampling with skip connections to the encoder
        u1 = deconv2d(d7, d6, self.gf*8)
        u2 = deconv2d(u1, d5, self.gf*8)
        u3 = deconv2d(u2, d4, self.gf*8)
        u4 = deconv2d(u3, d3, self.gf*4)
        u5 = deconv2d(u4, d2, self.gf*2)
        u6 = deconv2d(u5, d1, self.gf)

        # Last upsample restores the input resolution; tanh keeps outputs in [-1, 1].
        u7 = UpSampling2D(size=2)(u6)
        output_img = Conv2D(self.channels, kernel_size=4, strides=1,
                            padding='same', activation='tanh')(u7)
        model = Model(d0, output_img)

        return model

    def build_discriminator(self):
        """Build a PatchGAN discriminator.

        Four stride-2 conv blocks followed by a 1-filter conv, yielding a
        ``self.disc_patch``-shaped map of per-patch validity scores
        (linear output, trained with MSE).
        """

        def d_layer(layer_input, filters, f_size=4, normalization=True):
            """Discriminator block: stride-2 conv -> LeakyReLU -> optional BatchNorm."""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.4)(d)
            if normalization:
                d = BatchNormalization()(d)
            return d

        img = Input(shape=self.img_shape)

        d1 = d_layer(img, self.df, normalization=False)
        d2 = d_layer(d1, self.df*2)
        d3 = d_layer(d2, self.df*4)
        d4 = d_layer(d3, self.df*8)

        validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50, start_epoch=16):
        """Run the alternating discriminator / generator training loop.

        Args:
            epochs: Exclusive upper bound on the epoch index; epochs
                ``start_epoch .. epochs - 1`` are run.
            batch_size: Number of image pairs requested per batch from the
                data loader.
            sample_interval: Currently unused; periodic sampling through
                ``sample_images`` is disabled in this loop.
            start_epoch: First epoch index, used when resuming. The default
                of 16 matches the epoch-15 checkpoint that ``__init__`` loads
                by default.

        After every epoch the A->B generator is saved to
        ``saved_model/actual_model<epoch>.h5`` and the combined-model weights
        to ``saved_model/model<epoch>.h5``.
        """
        # Make sure the checkpoint directory exists before the first save.
        os.makedirs('saved_model', exist_ok=True)

        start_time = datetime.datetime.now()

        for epoch in range(start_epoch, epochs):

            for batch_i, (imgs_A, imgs_B) in enumerate(self.data_loader.load_batch(batch_size)):

                # Adversarial ground truths, sized to the actual batch so a
                # short final batch does not cause a shape mismatch.
                n_imgs = len(imgs_A)
                valid = np.ones((n_imgs,) + self.disc_patch)
                fake = np.zeros((n_imgs,) + self.disc_patch)

                # ----------------------
                #  Train Discriminators
                # ----------------------

                # Translate images to opposite domain
                fake_B = self.g_AB.predict(imgs_A)
                fake_A = self.g_BA.predict(imgs_B)

                # Train the discriminators (original images = real / translated = Fake)
                dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
                dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
                dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

                dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
                dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
                dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

                # Total discriminator loss ([loss, accuracy] averaged over both Ds)
                d_loss = 0.5 * np.add(dA_loss, dB_loss)

                # ------------------
                #  Train Generators
                # ------------------

                # Targets follow the combined model's output order:
                # [valid_A, valid_B, fake_B, fake_A, reconstr_A, reconstr_B]
                g_loss = self.combined.train_on_batch([imgs_A, imgs_B],
                                                      [valid, valid,
                                                       imgs_B, imgs_A,
                                                       imgs_A, imgs_B])

                elapsed_time = datetime.datetime.now() - start_time

                # Plot the progress
                print("[%d] [%d/%d] time: %s, [d_loss: %f, g_loss: %f]" % (epoch, batch_i,
                                                                        self.data_loader.n_batches,
                                                                        elapsed_time,
                                                                        d_loss[0], g_loss[0]))

                # Sampling is disabled; re-enable to use sample_interval:
                #if batch_i % sample_interval == 0:
                #   self.sample_images(epoch, batch_i)

            # serialize weights to HDF5
            self.g_AB.save("saved_model/actual_model"+str(epoch)+".h5")
            self.combined.save_weights("saved_model/model"+str(epoch)+".h5")
            print("Saved model "+str(epoch)+" to disk")

    def sample_images(self, epoch, batch_i):
        """Save a 2x3 grid of original / translated / reconstructed test images.

        The top row shows A -> B -> A and the bottom row B -> A -> B. The grid
        is written to ``images/<dataset_name>/<epoch>_<batch_i>.png``.
        """
        os.makedirs('images/%s' % self.dataset_name, exist_ok=True)
        r, c = 2, 3

        imgs_A, imgs_B = self.data_loader.load_data(batch_size=1, is_testing=True)

        # Translate images to the other domain
        fake_B = self.g_AB.predict(imgs_A)
        fake_A = self.g_BA.predict(imgs_B)
        # Translate back to original domain
        reconstr_A = self.g_BA.predict(fake_B)
        reconstr_B = self.g_AB.predict(fake_A)

        gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A, imgs_B, fake_A, reconstr_B])

        # Rescale generator range [-1, 1] to [0, 1] for imshow
        gen_imgs = 0.5 * gen_imgs + 0.5

        titles = ['Original', 'Translated', 'Reconstructed']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt])
                axs[i, j].set_title(titles[j])
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("images/%s/%d_%d.png" % (self.dataset_name, epoch, batch_i))
        # Close this specific figure so repeated calls do not leak figures.
        plt.close(fig)

if __name__ == '__main__':
    # Script entry point: build the model, then run training with a batch
    # size of 4 for epoch indices below 18.
    trainer = DiscoGAN()
    trainer.train(
        epochs=18,
        batch_size=4,
        sample_interval=50,
    )


Using TensorFlow backend.
  if issubdtype(ts, int):
  elif issubdtype(type(size), float):
  'Discrepancy between trainable weights and collected trainable'


[16] [0/435] time: 0:00:47.977190, [d_loss: 0.244784, g_loss: 0.821211]
[16] [1/435] time: 0:00:58.680112, [d_loss: 0.152257, g_loss: 1.050221]
[16] [2/435] time: 0:01:08.163687, [d_loss: 0.170006, g_loss: 1.042357]
[16] [3/435] time: 0:01:17.522893, [d_loss: 0.055963, g_loss: 1.024050]
[16] [4/435] time: 0:01:27.047800, [d_loss: 0.125756, g_loss: 1.357282]
[16] [5/435] time: 0:01:36.393678, [d_loss: 0.138142, g_loss: 1.105225]
[16] [6/435] time: 0:01:45.594870, [d_loss: 0.076915, g_loss: 0.973359]
[16] [7/435] time: 0:01:55.082215, [d_loss: 0.056366, g_loss: 0.835286]
[16] [8/435] time: 0:02:04.478982, [d_loss: 0.096328, g_loss: 0.837910]
[16] [9/435] time: 0:02:13.854483, [d_loss: 0.022641, g_loss: 0.917081]
[16] [10/435] time: 0:02:23.250118, [d_loss: 0.027821, g_loss: 1.157805]
[16] [11/435] time: 0:02:32.609037, [d_loss: 0.162122, g_loss: 1.029477]
[16] [12/435] time: 0:02:42.044007, [d_loss: 0.083148, g_loss: 1.145879]
[16] [13/435] time: 0:02:51.430064, [d_loss: 0.118210, g_loss

[16] [113/435] time: 0:18:18.182783, [d_loss: 0.014027, g_loss: 1.060748]
[16] [114/435] time: 0:18:27.467522, [d_loss: 0.019584, g_loss: 0.992877]
[16] [115/435] time: 0:18:36.624652, [d_loss: 0.025039, g_loss: 1.189870]
[16] [116/435] time: 0:18:45.781803, [d_loss: 0.027457, g_loss: 0.754089]
[16] [117/435] time: 0:18:55.109750, [d_loss: 0.028017, g_loss: 1.021206]
[16] [118/435] time: 0:19:04.371543, [d_loss: 0.025942, g_loss: 1.187881]
[16] [119/435] time: 0:19:13.610505, [d_loss: 0.021541, g_loss: 1.002033]
[16] [120/435] time: 0:19:22.917719, [d_loss: 0.047847, g_loss: 1.303359]
[16] [121/435] time: 0:19:32.110601, [d_loss: 0.081367, g_loss: 1.179337]
[16] [122/435] time: 0:19:41.365558, [d_loss: 0.038624, g_loss: 1.084696]
[16] [123/435] time: 0:19:50.593163, [d_loss: 0.027862, g_loss: 1.089847]
[16] [124/435] time: 0:19:59.802857, [d_loss: 0.029035, g_loss: 0.993284]
[16] [125/435] time: 0:20:09.071295, [d_loss: 0.028007, g_loss: 1.097177]
[16] [126/435] time: 0:20:18.281062, [

[16] [224/435] time: 0:35:27.160647, [d_loss: 0.048327, g_loss: 0.956988]
[16] [225/435] time: 0:35:36.144907, [d_loss: 0.058790, g_loss: 0.919583]
[16] [226/435] time: 0:35:45.461189, [d_loss: 0.059516, g_loss: 0.776746]
[16] [227/435] time: 0:35:54.694307, [d_loss: 0.011918, g_loss: 0.840529]
[16] [228/435] time: 0:36:03.888759, [d_loss: 0.014115, g_loss: 0.825762]
[16] [229/435] time: 0:36:13.196127, [d_loss: 0.022120, g_loss: 0.863413]
[16] [230/435] time: 0:36:22.580772, [d_loss: 0.011610, g_loss: 0.809491]
[16] [231/435] time: 0:36:31.907656, [d_loss: 0.010837, g_loss: 0.856619]
