In [1]:
from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import RMSprop

import keras.backend as K
#from keras.applications.vgg16 import VGG16

import matplotlib.pyplot as plt
%matplotlib inline
import sys

import numpy as np

class WGAN():
    """Image-to-image GAN with WGAN-style critic training.

    The generator is a convolutional encoder/decoder that maps an input image
    of shape ``img_shape`` to a single-channel output image of the same
    spatial size (tanh output, so values lie in [-1, 1]). The critic scores
    an image with a single tanh unit.

    Despite the class name, the critic is compiled with an MSE loss against
    targets -1 (real) / +1 (generated) rather than ``wasserstein_loss``
    (kept below as an alternative). WGAN-style weight clipping and
    ``n_critic`` critic updates per generator update are retained.

    The combined model trains the generator on a weighted sum of an MSE
    reconstruction loss against the target image (weight 100) and an MAE
    adversarial loss on the critic's score (weight 1).
    """

    # Indices into X_test drawn in the periodic sample figure when train()
    # is called without ``sample_idx``.
    DEFAULT_SAMPLE_IDX = (278, 451, 3263, 2654)

    def __init__(self):
        self.img_rows = 128
        self.img_cols = 128
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # Following parameter and optimizer set as recommended in the WGAN paper
        self.n_critic = 5
        self.clip_value = 0.01
        optimizer = RMSprop(lr=0.00005)

        # Build and compile the critic (MSE against -1 / +1 targets).
        self.critic = self.build_critic_map()
        self.critic.compile(loss='mse',
                            optimizer=optimizer,
                            metrics=['accuracy'])

        # Build the generator.
        self.generator = self.build_generator_map()

        # The generator takes an input image (not noise) and produces an image.
        img_in = Input(shape=self.img_shape)
        img = self.generator(img_in)

        # Freeze the critic inside the combined model only: the critic was
        # already compiled above, so its own train_on_batch still updates it.
        self.critic.trainable = False

        # The critic scores the generated image.
        valid = self.critic(img)

        # Combined model (stacked generator and critic). Outputs are the
        # generated image (MSE vs. the target image, weight 100) and the
        # critic score (MAE vs. the "real" label, weight 1).
        self.combined = Model(img_in, [img, valid])
        self.combined.compile(loss=['mse', 'mae'],
                              loss_weights=[100, 1],
                              optimizer=optimizer,
                              metrics=['accuracy'])

    def wasserstein_loss(self, y_true, y_pred):
        """Wasserstein loss ``mean(y_true * y_pred)``; unused by the current compile calls."""
        return K.mean(y_true * y_pred)

    @staticmethod
    def _add_conv(model, filters, kernel, stride, activation, batch_norm=True, **conv_kwargs):
        """Append Conv2D('same') -> activation [-> BatchNormalization(0.8)] to ``model``.

        ``activation`` is ``'leaky'`` for LeakyReLU(alpha=0.2); any other value
        is passed to keras ``Activation``. Extra keyword arguments (e.g.
        ``input_shape``) are forwarded to ``Conv2D``.
        """
        model.add(Conv2D(filters, kernel_size=(kernel, kernel), strides=(stride, stride),
                         padding="same", **conv_kwargs))
        if activation == 'leaky':
            model.add(LeakyReLU(alpha=0.2))
        else:
            model.add(Activation(activation))
        if batch_norm:
            model.add(BatchNormalization(momentum=0.8))

    def build_generator_map(self):
        """Build the encoder/decoder generator.

        Encoder: 14 'same'-padded convs, three of them with stride 2, so a
        128x128 input is reduced to 16x16. The last encoder conv has no
        BatchNormalization and is followed by Dropout(0.5).
        Decoder: three UpSampling2D(2) stages back to the input size, each
        followed by convs, ending in a 1-channel tanh output.

        Returns a keras ``Model`` mapping an ``img_shape`` image to a
        1-channel image of the same spatial size, values in [-1, 1].
        """
        model = Sequential()

        # Encoder
        self._add_conv(model, 48, 5, 2, 'leaky', input_shape=self.img_shape)  # down sample
        self._add_conv(model, 128, 3, 1, 'leaky')
        self._add_conv(model, 128, 3, 1, 'leaky')

        self._add_conv(model, 256, 3, 2, 'leaky')  # down sample
        self._add_conv(model, 256, 3, 1, 'leaky')
        self._add_conv(model, 256, 3, 1, 'leaky')

        self._add_conv(model, 256, 3, 2, 'leaky')  # down sample
        for filters in (512, 1024, 1024, 1024, 1024, 512):
            self._add_conv(model, filters, 3, 1, 'leaky')
        self._add_conv(model, 256, 3, 1, 'leaky', batch_norm=False)
        model.add(Dropout(0.5))

        # Decoder
        model.add(UpSampling2D(size=(2, 2)))
        self._add_conv(model, 256, 4, 1, 'relu')
        self._add_conv(model, 256, 3, 1, 'leaky')
        self._add_conv(model, 128, 3, 1, 'leaky')

        model.add(UpSampling2D(size=(2, 2)))
        self._add_conv(model, 128, 4, 1, 'relu')
        self._add_conv(model, 128, 3, 1, 'relu')
        self._add_conv(model, 48, 3, 1, 'relu')

        model.add(UpSampling2D(size=(2, 2)))
        self._add_conv(model, 48, 4, 1, 'relu')
        self._add_conv(model, 24, 3, 1, 'relu')
        self._add_conv(model, 1, 3, 1, 'tanh', batch_norm=False)  # range -1 .. 1

        model.summary()

        X = Input(shape=self.img_shape)
        yhat = model(X)

        return Model(X, yhat)

    def build_critic_map(self):
        """Build the critic: four stride-2 convs with dropout, then one tanh unit.

        Returns a keras ``Model`` mapping an ``img_shape`` image to a
        ``(batch, 1)`` score in [-1, 1].
        """
        model = Sequential()

        self._add_conv(model, 16, 5, 2, 'leaky', batch_norm=False,
                       input_shape=self.img_shape)
        model.add(Dropout(0.25))

        self._add_conv(model, 32, 3, 2, 'leaky')
        model.add(Dropout(0.25))

        self._add_conv(model, 64, 3, 2, 'leaky')
        model.add(Dropout(0.5))

        self._add_conv(model, 128, 3, 2, 'leaky')
        model.add(Dropout(0.5))

        model.add(Flatten())
        model.add(Dense(1, activation='tanh'))

        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    @staticmethod
    def _to_tanh_range(data):
        """Map [0, 1] values to [-1, 1] via ``(data - 0.5) * 2`` (same formula as ``tranform``)."""
        return (data - 0.5) * 2

    def _clip_critic_weights(self):
        """Clip every critic weight array element-wise to [-clip_value, clip_value].

        NOTE(review): this iterates every weight array returned by
        ``get_weights`` for each critic layer, which presumably includes the
        BatchNormalization statistics as well. The original loop did the
        same; confirm this is intended.
        """
        for layer in self.critic.layers:
            weights = layer.get_weights()
            layer.set_weights([np.clip(w, -self.clip_value, self.clip_value) for w in weights])

    def train(self, data, epochs, batch_size=128, sample_interval=50, sample_idx=None):
        """Alternate ``n_critic`` critic updates with one generator update.

        Parameters
        ----------
        data : sequence of 4 arrays
            ``(X_train, y_train, X_test, y_test)``. Each is rescaled with
            ``(x - 0.5) * 2`` before use.
        epochs : int
            Number of generator updates, each preceded by ``n_critic`` critic
            updates on random mini-batches.
        batch_size : int
            Mini-batch size, sampled with replacement from the training set.
        sample_interval : int
            ``sample_maps`` is called whenever ``epoch % sample_interval == 0``
            (so also at epoch 0).
        sample_idx : sequence of int, optional
            Indices into ``X_test`` to plot; defaults to ``DEFAULT_SAMPLE_IDX``.
        """
        # Rescale to the generator's tanh range. A class helper is used so the
        # class does not depend on a function defined in a later notebook cell.
        X_train, y_train, X_test, y_test = [self._to_tanh_range(ds) for ds in data]

        # Always a list (never a tuple) so numpy treats it as fancy indexing
        # along axis 0 rather than as a multi-dimensional index.
        if sample_idx is None:
            sample_idx = self.DEFAULT_SAMPLE_IDX
        sample_idx = list(sample_idx)

        # Adversarial ground truths: real -> -1, generated -> +1.
        valid = -np.ones((batch_size, 1))
        fake = np.ones((batch_size, 1))

        for epoch in range(epochs):

            print("Epoch", epoch)
            for _ in range(self.n_critic):

                # ---------------------
                #  Train critic
                # ---------------------
                idx = np.random.randint(0, X_train.shape[0], batch_size)
                batch_X_train = X_train[idx]
                batch_y_train = y_train[idx]

                gen_imgs = self.generator.predict(batch_X_train)

                # Target images are labelled `valid`, generated ones `fake`.
                d_loss_real = self.critic.train_on_batch(batch_y_train, valid)
                d_loss_fake = self.critic.train_on_batch(gen_imgs, fake)
                d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

                self._clip_critic_weights()

            # ---------------------
            #  Train generator
            # ---------------------
            # Reuses the last critic mini-batch: reconstruction towards the
            # target image plus an adversarial term towards `valid`.
            g_loss = self.combined.train_on_batch(batch_X_train, [batch_y_train, valid])

            print("%d [D loss: %f, acc: %.2f%%] [G loss: %f, mse: %f]" % (
                epoch, d_loss[0], 100 * d_loss[1], g_loss[0], g_loss[1]))

            if epoch % sample_interval == 0:
                self.sample_maps(epoch, X_test[sample_idx], y_test[sample_idx])

    def predict(self, img_in):
        """Return the generator's output for a batch of input images (tanh range [-1, 1])."""
        return self.generator.predict(img_in)

    def sample_maps(self, epoch, img_in, img_out):
        """Save a 4x4 figure comparing inputs, targets and generator outputs.

        Rows: input, target, generated output, generated output > 0.2.
        All rows are mapped from [-1, 1] to [0, 1] before plotting. Up to four
        samples (columns) are drawn; unused columns stay blank. The figure is
        written to ``images2/<epoch>.png``, creating the directory if needed.
        """
        import os  # local import: the notebook's import cell does not import os

        r, c = 4, 4
        out_dir = "images2"

        gen_missing = 0.5 * self.generator.predict(img_in) + 0.5
        shown_in = 0.5 * img_in + 0.5
        shown_out = 0.5 * img_out + 0.5
        rows = (shown_in, shown_out, gen_missing, gen_missing > 0.2)
        shape2d = (self.img_rows, self.img_cols)

        fig, axs = plt.subplots(r, c, figsize=(15, 15))
        for i in range(c):
            for j, stack in enumerate(rows):
                axs[j, i].axis('off')
                if i < len(stack):
                    axs[j, i].imshow(np.reshape(stack[i], shape2d))

        fig.suptitle(str(epoch), fontsize=20)
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)
        fig.savefig(out_dir + "/%d.png" % epoch)
        plt.close(fig)  # close this figure explicitly to avoid leaking figures in long runs

Using TensorFlow backend.


In [2]:
def tranform(data):
    """Shift and stretch values from [0, 1] to [-1, 1] (the generator's tanh range).

    Applied element-wise to NumPy arrays or scalars: 0 -> -1, 0.5 -> 0, 1 -> 1.
    Inputs outside [0, 1] are mapped linearly as well (no clipping).
    """
    centered = data - 0.5
    return centered * 2

def load_dataset(scale=25, data_dir=r"../tmp_data/data_feng"):
    """Load the pre-split simulation arrays for one map scale.

    Reads ``x_train_sim.npy``, ``y_train_sim.npy``, ``x_test_sim.npy`` and
    ``y_test_sim.npy`` from ``<data_dir>/geb<scale>/``.

    Parameters
    ----------
    scale : int or str
        Scale identifier appended to the ``geb`` sub-directory name.
    data_dir : str
        Parent directory containing the ``geb<scale>`` folders.

    Returns
    -------
    tuple of numpy.ndarray
        ``(x_train, y_train, x_test, y_test)`` exactly as stored on disk.
    """
    import os  # local import: os is not imported by the notebook's import cell

    # Original author's note on patch size: the models here use 128-pixel
    # patches; larger (256) patches tended to generate rounded corners.
    train_path = os.path.join(data_dir, "geb" + str(scale))
    names = ("x_train_sim", "y_train_sim", "x_test_sim", "y_test_sim")
    return tuple(np.load(os.path.join(train_path, name + ".npy")) for name in names)

In [None]:
if __name__ == '__main__':
    # Inside a notebook kernel __name__ is '__main__', so this always runs.
    # `data` is the (x_train, y_train, x_test, y_test) tuple later passed to WGAN.train.
    data = load_dataset()

In [3]:
    # Builds and compiles the critic, the generator and the combined model;
    # both build_* methods print a model summary (see the output below).
    wgan = WGAN()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_1 (Conv2D)            (None, 64, 64, 16)        416       
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 64, 64, 16)        0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 64, 64, 16)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 32, 32, 32)        4640      
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 32, 32, 32)        0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 32, 32, 32)        128       
_________________________________________________________________
dropout_2 (Dropout)          (None, 32, 32, 32)        0         
__________

In [None]:
    # 4000 generator updates (each after n_critic critic updates) on batches of 32;
    # a sample figure is saved under images2/ every 50 epochs.
    wgan.train(data, epochs=4000, batch_size=32, sample_interval=50)