In [1]:
import tensorflow as tf
import keras
from keras.layers import LeakyReLU, BatchNormalization, Input, Activation, Concatenate
import numpy as np
from keras.initializers import RandomNormal
from numpy.random import randint
from matplotlib import pyplot

In [2]:
# Shape of every source/target image: (height, width, channels), in Keras'
# default channels-last layout.
imgs_shape = (32, 32, 3)

In [3]:
# Discriminator

def discriminator(image_shape=None):
  """Build and compile a patch discriminator for (source, target) image pairs.

  The source and target images are concatenated along the channel axis and
  passed through four 4x4, stride-2, 'same'-padded convolutions
  (32, 64, 128, 256 filters; batch normalisation on all but the first) with
  LeakyReLU(0.2), then a 4x4 convolution to one channel with a sigmoid.
  The output is a grid of real/fake scores whose height and width are the
  input's divided by 16, rounded up (2x2 for 32x32 images).

  Args:
    image_shape: shape of one image, (height, width, channels). Defaults to
      the notebook-level ``imgs_shape``.

  Returns:
    A compiled Keras model taking ``[source_image, target_image]``.
  """
  if image_shape is None:
    image_shape = imgs_shape
  init = RandomNormal(stddev=0.02)
  in_src_image = Input(shape=image_shape)
  in_target_image = Input(shape=image_shape)
  d = Concatenate()([in_src_image, in_target_image])

  # Each stride-2 stage halves height and width.
  for stage, filters in enumerate((32, 64, 128, 256)):
    d = tf.keras.layers.Conv2D(filters, (4, 4), strides=(2, 2), padding='same',
                               kernel_initializer=init)(d)
    if stage > 0:  # the first stage has no batch normalisation
      d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)

  d = tf.keras.layers.Conv2D(1, (4, 4), padding='same', kernel_initializer=init)(d)
  patch_out = Activation('sigmoid')(d)

  model = keras.models.Model([in_src_image, in_target_image], patch_out)

  opt = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
  # loss_weights=[0.5] halves the discriminator's loss (and so its gradients),
  # slowing its learning relative to the generator.
  model.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])
  return model

In [4]:
# Generator

def generator(image_shape=None):
    """Build the (uncompiled) generator that translates a source image into a target image.

    Encoder: three stages of unpadded ('valid') 5x5 then 3x3 convolutions with
    LeakyReLU, followed by a same-padded 3x3 tanh bottleneck. Decoder:
    unpadded transposed convolutions that undo the encoder's spatial shrinkage
    exactly (each 5x5 layer changes height and width by 4, each 3x3 by 2),
    with skip connections concatenating encoder feature maps of matching size.
    Input height and width must therefore be at least 19.

    Args:
        image_shape: shape of one image, (height, width, channels). Defaults
            to the notebook-level ``imgs_shape``.

    Returns:
        A ``tf.keras`` model whose output has the same shape as its input and
        values in (-1, 1).
    """
    if image_shape is None:
        image_shape = imgs_shape
    X = Input(shape=image_shape)

    # Encoder (unpadded): each 5x5 conv trims 4 pixels from H and W, each 3x3 trims 2.
    conv1 = tf.keras.layers.Conv2D(32, kernel_size=(5, 5), strides=1)(X)
    conv1 = tf.keras.layers.LeakyReLU()(conv1)
    conv1 = tf.keras.layers.Conv2D(64, kernel_size=(3, 3), strides=1)(conv1)
    conv1 = tf.keras.layers.LeakyReLU()(conv1)

    conv2 = tf.keras.layers.Conv2D(64, kernel_size=(5, 5), strides=1)(conv1)
    conv2 = tf.keras.layers.LeakyReLU()(conv2)
    conv2 = tf.keras.layers.Conv2D(128, kernel_size=(3, 3), strides=1)(conv2)
    conv2 = tf.keras.layers.LeakyReLU()(conv2)

    conv3 = tf.keras.layers.Conv2D(128, kernel_size=(5, 5), strides=1)(conv2)
    conv3 = tf.keras.layers.LeakyReLU()(conv3)
    conv3 = tf.keras.layers.Conv2D(256, kernel_size=(3, 3), strides=1)(conv3)
    conv3 = tf.keras.layers.LeakyReLU()(conv3)

    bottleneck = tf.keras.layers.Conv2D(256, kernel_size=(3, 3), strides=1, activation='tanh', padding='same')(conv3)

    # Decoder (unpadded transposed convs): each 3x3 adds 2 pixels, each 5x5 adds 4,
    # so every Concatenate below joins tensors of equal height and width.
    concat_1 = tf.keras.layers.Concatenate()([bottleneck, conv3])
    conv_up_3 = tf.keras.layers.Conv2DTranspose(128, kernel_size=(3, 3), strides=1, activation='relu')(concat_1)
    conv_up_3 = tf.keras.layers.Conv2DTranspose(64, kernel_size=(5, 5), strides=1, activation='relu')(conv_up_3)

    concat_2 = tf.keras.layers.Concatenate()([conv_up_3, conv2])
    conv_up_2 = tf.keras.layers.Conv2DTranspose(64, kernel_size=(3, 3), strides=1, activation='relu')(concat_2)
    conv_up_2 = tf.keras.layers.Conv2DTranspose(32, kernel_size=(5, 5), strides=1, activation='relu')(conv_up_2)

    concat_3 = tf.keras.layers.Concatenate()([conv_up_2, conv1])
    conv_up_1 = tf.keras.layers.Conv2DTranspose(32, kernel_size=(3, 3), strides=1, activation='relu')(concat_3)
    # Output layer: one filter per image channel. The activation is tanh rather
    # than relu because load_dataset scales images to [-1, 1], the "mae" loss
    # compares against those values, and the (x + 1) / 2 display mapping assumes
    # them. A relu output could never reach the negative half of that range.
    conv_up_1 = tf.keras.layers.Conv2DTranspose(image_shape[-1], kernel_size=(5, 5), strides=1, activation='tanh')(conv_up_1)

    model = tf.keras.models.Model(X, conv_up_1)
    return model


In [5]:
# Combined GAN model: the generator feeding into the frozen discriminator, trained as one model
def gan_model(generator, discriminator, input_img):
  """Chain the generator and discriminator into one compiled model for generator training.

  Args:
    generator: model mapping a source image to a generated target image.
    discriminator: model scoring ``[source, target]`` pairs; its
      ``trainable`` flag is set to False here.
    input_img: the *shape* of one source image (not an image), used for the
      combined model's input layer.

  Returns:
    A model with outputs ``[discriminator_score, generated_image]``, compiled
    with binary cross-entropy on the score and MAE on the image, weighted 1:100.
  """
  # Frozen before compiling, intended so that training this combined model
  # updates only the generator's weights.
  discriminator.trainable = False

  source = keras.layers.Input(shape=input_img)
  translated = generator(source)
  verdict = discriminator([source, translated])

  combined = keras.models.Model(inputs=source, outputs=[verdict, translated])
  combined.compile(
      loss=["binary_crossentropy", "mae"],
      optimizer=keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
      loss_weights=[1, 100],
  )
  return combined

In [6]:
def load_dataset(filename):
  """Load paired images from an ``.npz`` archive and scale them to [-1, 1].

  Args:
    filename: path to an archive holding two arrays under the keys
      ``arr_0`` and ``arr_1`` (the names ``np.savez`` gives positional arrays).

  Returns:
    ``[X1, X2]``: the two arrays mapped by (x - 127.5) / 127.5, so pixel
    values 0..255 become -1..1.
  """
  # An NpzFile keeps the archive open until closed; the context manager
  # releases the handle once both arrays have been read into memory.
  with np.load(filename) as data:
    X1, X2 = data["arr_0"], data["arr_1"]
  X1 = (X1 - 127.5) / 127.5
  X2 = (X2 - 127.5) / 127.5
  return [X1, X2]

In [7]:
def generate_real_images(dataset, n, patch_shape):
  """Sample ``n`` aligned (source, target) pairs with "real" labels.

  Indices are drawn independently (with replacement) and the same indices
  are used for both arrays, so each source stays paired with its target.

  Returns:
    ``([sources, targets], labels)`` where ``labels`` is an all-ones array
    of shape ``(n, patch_shape, patch_shape, 1)``.
  """
  sources, targets = dataset
  picks = randint(0, sources.shape[0], n)
  labels = np.ones((n, patch_shape, patch_shape, 1))
  return [sources[picks], targets[picks]], labels

In [8]:
def generate_fake_images(model, sample, patch_shape):
  """Run ``model`` on ``sample`` and pair its outputs with "fake" labels.

  Returns:
    ``(predictions, labels)`` where ``labels`` is an all-zeros array of shape
    ``(len(predictions), patch_shape, patch_shape, 1)``.
  """
  predicted = model.predict(sample)
  label_shape = (len(predicted),) + (patch_shape, patch_shape, 1)
  return predicted, np.zeros(label_shape)

In [9]:
def performance_check(step, model, dataset, n=3):
  """Save a sample comparison figure and the model's weights for a training step.

  Plots ``n`` random samples as a 3 x n grid (row 1: real source images,
  row 2: the model's outputs for them, row 3: real target images) to
  ``plot_<step+1>.png``, then saves the weights to
  ``model_<step+1>.weights.h5``; both numbers are zero-padded to 6 digits.

  Args:
    step: zero-based training step; filenames use ``step + 1``.
    model: generator exposing ``predict`` and ``save_weights``.
    dataset: ``[sources, targets]`` arrays scaled to [-1, 1].
    n: number of samples (grid columns) to plot.
  """
  [realA, realB], _ = generate_real_images(dataset, n, 1)
  fakeB, _ = generate_fake_images(model, realA, 1)

  # squeeze=False keeps ``axes`` 2-D even when n == 1.
  fig, axes = pyplot.subplots(3, n, squeeze=False)
  rows = (realA, fakeB, realB)  # real source, generated target, real target
  for row, images in enumerate(rows):
    for col in range(n):
      ax = axes[row, col]
      ax.axis("off")
      # Map [-1, 1] back to [0, 1] for display; clip in case generated
      # pixels fall slightly outside the range.
      ax.imshow(np.clip((images[col] + 1) / 2.0, 0.0, 1.0))
  filename1 = 'plot_%06d.png' % (step + 1)
  fig.savefig(filename1)
  pyplot.close(fig)

  # Zero-padded like the plot name. The '.weights.h5' suffix is HDF5 for
  # Keras 2 and is the suffix Keras 3's save_weights requires.
  filename2 = 'model_%06d.weights.h5' % (step + 1)
  model.save_weights(filename2)
  print(">saved: %s and %s" % (filename1, filename2))

In [None]:
# GAN training function

def train_gan(disc_model, gen_model, gan_model, dataset, epochs=150, batch=64):
  """Alternate discriminator and generator updates on randomly sampled batches.

  Each step draws ``batch`` random pairs, trains the discriminator once on
  real pairs (label 1) and once on generated pairs (label 0), then trains
  the generator through the combined ``gan_model`` with targets
  ``[real labels, real target images]``. ``performance_check`` saves a
  sample plot and the generator weights every 5 epochs and after the last step.

  Args:
    disc_model: compiled discriminator taking ``[source, target]``.
    gen_model: generator mapping source images to target images.
    gan_model: compiled combined model with outputs ``[score, image]``.
    dataset: ``[sources, targets]`` arrays of equal length.
    epochs: number of epochs; one epoch is ``len(sources) // batch`` steps
      (at least one), with batches sampled with replacement.
    batch: number of pairs per step.
  """
  n_patch = disc_model.output_shape[1]
  trainA, _ = dataset
  # At least one step per epoch, so a dataset smaller than ``batch`` still
  # trains instead of silently running zero steps.
  batch_per_epoch = max(1, len(trainA) // batch)
  steps = batch_per_epoch * epochs
  save_interval = batch_per_epoch * 5
  print("batches per epoch: %d, total steps: %d" % (batch_per_epoch, steps))

  for i in range(steps):
    [realA, realB], y_real = generate_real_images(dataset, batch, n_patch)
    fakeB, y_fake = generate_fake_images(gen_model, realA, n_patch)
    disc_loss1 = _total_loss(disc_model.train_on_batch([realA, realB], y_real))
    disc_loss2 = _total_loss(disc_model.train_on_batch([realA, fakeB], y_fake))
    gen_loss = _total_loss(gan_model.train_on_batch(realA, [y_real, realB]))
    print(">%d, d1[%.3f] d2[%.3f] g[%.3f]" % (i + 1, disc_loss1, disc_loss2, gen_loss))
    if (i + 1) % save_interval == 0 or i + 1 == steps:
      performance_check(i, gen_model, dataset)


def _total_loss(result):
  """Return the total loss from a ``train_on_batch`` result as a float.

  ``train_on_batch`` returns either a scalar or a list whose first entry is
  the total loss (e.g. for multi-output models); both are handled here.
  """
  return float(np.ravel(result)[0])

In [None]:
# Loading dataset

# Forward slashes are valid path separators on every OS, including Windows;
# a backslash in a plain string literal is parsed as an escape sequence.
DATASET_PATH = 'dataset/cifar10.npz'
dataset = load_dataset(DATASET_PATH)
print("Loaded", dataset[0].shape, dataset[1].shape)

In [None]:
image_shape = dataset[0].shape[1:]
# discriminator() and generator() size their inputs from the notebook-level
# imgs_shape, so stop early with a clear message if the data does not match.
for images in dataset:
  assert images.shape[1:] == imgs_shape, (
      "dataset images have shape %s but the models expect %s" % (images.shape[1:], imgs_shape))
disc_model = discriminator()
gen_model = generator()
# Use a new name for the combined model: rebinding ``gan_model`` would shadow
# the builder function and make this cell fail when run a second time.
gan = gan_model(gen_model, disc_model, image_shape)
train_gan(disc_model, gen_model, gan, dataset)