# derived from https://machinelearningmastery.com/how-to-implement-pix2pix-gan-models-from-scratch-with-keras/

# https://machinelearningmastery.com/how-to-develop-a-pix2pix-gan-for-image-to-image-translation/

In [1]:
# Sanity-check cell: load Fashion-MNIST through Keras and print the array
# shapes. The captured output below shows 28x28 single-channel images; the
# arrays are not used by the training run further down (only in a
# commented-out `dataset` line).
from keras.datasets.fashion_mnist import load_data
# load the images into memory
(trainX, trainy), (testX, testy) = load_data()
# summarize the shape of the dataset
print('Train', trainX.shape, trainy.shape)
print('Test', testX.shape, testy.shape)

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


Train (60000, 28, 28) (60000,)
Test (10000, 28, 28) (10000,)


In [23]:
import os
# Expose only the GPU with index 1 to CUDA-based libraries in this process.
# NOTE(review): the earlier cell already imported Keras with the TensorFlow
# backend; confirm the variable is still honoured when set at this point.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

def receptive_field(output_size, kernel_size, stride_size):
    """Return the input extent one convolution layer needs for a given output extent.

    Apply it repeatedly, from the last layer back to the first, to obtain
    the receptive field of a stack of convolutions.

    Args:
        output_size: extent (in pixels/units) at the layer's output.
        kernel_size: size of the convolution kernel along that axis.
        stride_size: stride of the convolution along that axis.

    Returns:
        The extent at the layer's input covered by `output_size` outputs.
    """
    # Distance between the first and last kernel positions, in input units.
    strided_span = stride_size * (output_size - 1)
    # The last kernel position still covers a full kernel width.
    return strided_span + kernel_size

# example of defining a 70x70 patchgan discriminator model
from keras.optimizers import Adam
from keras.initializers import RandomNormal
from keras.models import Model
from keras.models import Input
from keras.layers import Conv2D, Conv2DTranspose, Dropout
from keras.layers import LeakyReLU
from keras.layers import Activation
from keras.layers import Concatenate
from keras.layers import BatchNormalization
from keras.utils.vis_utils import plot_model
import numpy as np

# define the discriminator model
def define_discriminator(image_shape):
    """Build and compile the PatchGAN discriminator.

    The model takes a (source, target) image pair, concatenates the two
    images along the channel axis and produces a map of per-patch
    probabilities through a sigmoid.

    Args:
        image_shape: shape of one image (without the batch dimension);
            the same shape is used for both the source and target inputs.

    Returns:
        A compiled Keras `Model` with inputs `[source, target]` and a
        single sigmoid patch output, trained with binary cross-entropy
        weighted by 0.5.
    """
    kernel_init = RandomNormal(stddev=0.02)

    # Two image inputs of identical shape, stacked channel-wise.
    src_input = Input(shape=image_shape)
    target_input = Input(shape=image_shape)
    features = Concatenate()([src_input, target_input])

    # C64: first downsampling stage has no batch normalization.
    features = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=kernel_init)(features)
    features = LeakyReLU(alpha=0.2)(features)

    # C128, C256, C512: stride-2 conv -> batch norm -> leaky ReLU.
    for n_filters in (128, 256, 512):
        features = Conv2D(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=kernel_init)(features)
        features = BatchNormalization()(features)
        features = LeakyReLU(alpha=0.2)(features)

    # Second-to-last stage keeps the spatial size (default stride).
    features = Conv2D(512, (4,4), padding='same', kernel_initializer=kernel_init)(features)
    features = BatchNormalization()(features)
    features = LeakyReLU(alpha=0.2)(features)

    # One-channel patch map squashed to (0, 1).
    patch_scores = Conv2D(1, (4,4), padding='same', kernel_initializer=kernel_init)(features)
    patch_out = Activation('sigmoid')(patch_scores)

    discriminator = Model([src_input, target_input], patch_out)
    discriminator.compile(
        loss='binary_crossentropy',
        optimizer=Adam(lr=0.0002, beta_1=0.5),
        loss_weights=[0.5],
    )
    return discriminator

# define an encoder block
def define_encoder_block(layer_in, n_filters, batchnorm=True):
    """Append one downsampling stage of the generator's encoder.

    The stage is a 4x4 stride-2 convolution, optionally followed by batch
    normalization, then a leaky ReLU with slope 0.2.

    Args:
        layer_in: tensor the stage is applied to.
        n_filters: number of convolution filters.
        batchnorm: when true, insert batch normalization (called with
            `training=True`) between the convolution and the activation.

    Returns:
        The output tensor of the stage.
    """
    kernel_init = RandomNormal(stddev=0.02)
    downsample = Conv2D(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=kernel_init)
    out = downsample(layer_in)
    if batchnorm:
        # training=True keeps the layer in training mode on every call.
        out = BatchNormalization()(out, training=True)
    return LeakyReLU(alpha=0.2)(out)
 
# define a decoder block
def decoder_block(layer_in, skip_in, n_filters, dropout=True):
    """Append one upsampling stage of the generator's decoder.

    The stage is a 4x4 stride-2 transposed convolution followed by batch
    normalization and optional dropout; the result is concatenated with
    the skip tensor and passed through a ReLU.

    Args:
        layer_in: tensor to upsample.
        skip_in: encoder tensor merged in after upsampling.
        n_filters: number of transposed-convolution filters.
        dropout: when true, apply `Dropout(0.5)` (called with
            `training=True`) before the merge.

    Returns:
        The activated, concatenated output tensor.
    """
    kernel_init = RandomNormal(stddev=0.02)
    upsample = Conv2DTranspose(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=kernel_init)
    out = upsample(layer_in)
    # Both normalization and dropout are forced into training mode.
    out = BatchNormalization()(out, training=True)
    if dropout:
        out = Dropout(0.5)(out, training=True)
    # Skip connection: upsampled features first, encoder features second.
    merged = Concatenate()([out, skip_in])
    return Activation('relu')(merged)
 
# define the standalone generator model
def define_generator(image_shape=(256,256,3)):
    """Build the standalone (uncompiled) U-Net style generator.

    The encoder is seven `define_encoder_block` stages plus a stride-2
    bottleneck convolution; the decoder is seven `decoder_block` stages,
    each merged with the matching encoder output, followed by a final
    stride-2 transposed convolution and a tanh.

    Args:
        image_shape: shape of the input image without the batch
            dimension. Its last entry is used as the number of output
            channels (assumes channels-last layout), so the generated
            image has the same shape as the input, as the discriminator
            in this file requires for its source/target pair.

    Returns:
        A Keras `Model` mapping the input image to a tanh-activated image.

    NOTE(review): there are eight stride-2 downsampling steps, so the
    skip connections presumably only line up when height and width are
    multiples of 256 -- confirm before using other input sizes.
    """
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # image input
    in_image = Input(shape=image_shape)
    # Previously hard-coded to 3; derive it so non-RGB shapes produce an
    # output the discriminator can concatenate with the source image.
    n_out_channels = image_shape[-1]
    # encoder model: C64-C128-C256-C512-C512-C512-C512-C512
    e1 = define_encoder_block(in_image, 64, batchnorm=False)
    e2 = define_encoder_block(e1, 128)
    e3 = define_encoder_block(e2, 256)
    e4 = define_encoder_block(e3, 512)
    e5 = define_encoder_block(e4, 512)
    e6 = define_encoder_block(e5, 512)
    e7 = define_encoder_block(e6, 512)
    # bottleneck, no batch norm and relu
    b = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(e7)
    b = Activation('relu')(b)
    # decoder model: CD512-CD1024-CD1024-C1024-C1024-C512-C256-C128
    # (dropout only on the first three decoder stages)
    d1 = decoder_block(b, e7, 512)
    d2 = decoder_block(d1, e6, 512)
    d3 = decoder_block(d2, e5, 512)
    d4 = decoder_block(d3, e4, 512, dropout=False)
    d5 = decoder_block(d4, e3, 256, dropout=False)
    d6 = decoder_block(d5, e2, 128, dropout=False)
    d7 = decoder_block(d6, e1, 64, dropout=False)
    # output: final upsampling back to the input resolution, values in (-1, 1)
    g = Conv2DTranspose(n_out_channels, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d7)
    out_image = Activation('tanh')(g)
    # define model
    model = Model(in_image, out_image)
    return model
 

In [36]:
def define_gan(g_model, d_model, image_shape):
    """Build and compile the composite model used to update the generator.

    The source image is fed to the generator; the source and the generated
    image are then fed to the discriminator. The composite model outputs
    both the discriminator's patch prediction and the generated image.

    Args:
        g_model: generator model.
        d_model: discriminator model; its `trainable` flag is set to
            False here (a side effect on the passed-in object).
        image_shape: shape of the source image without the batch dimension.

    Returns:
        A compiled Keras `Model` with outputs `[patch prediction,
        generated image]`, trained with binary cross-entropy and MAE
        weighted 1 and 100 respectively.
    """
    # Freeze the discriminator inside the composite model.
    # NOTE(review): relies on d_model having been compiled beforehand so
    # that its own train_on_batch still updates it -- confirm against the
    # Keras version in use.
    d_model.trainable = False

    source = Input(shape=image_shape)
    generated = g_model(source)
    patch_prediction = d_model([source, generated])

    composite = Model(source, [patch_prediction, generated])
    composite.compile(
        loss=['binary_crossentropy', 'mae'],
        optimizer=Adam(lr=0.0002, beta_1=0.5),
        loss_weights=[1,100],
    )
    return composite

# Build the three models for 256x256 three-channel images. Order matters:
# the discriminator is compiled inside define_discriminator before
# define_gan flips its `trainable` flag.
# define image shape
image_shape = (256,256,3)
# define the models
d_model = define_discriminator(image_shape)
g_model = define_generator(image_shape)
# define the composite model
gan_model = define_gan(g_model, d_model, image_shape)
# summarize the model
gan_model.summary()
# plot the model (writes gan_model_plot.png to the working directory)
plot_model(gan_model, to_file='gan_model_plot.png', show_shapes=True, show_layer_names=True)

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_17 (InputLayer)           (None, 256, 256, 3)  0                                            
__________________________________________________________________________________________________
model_8 (Model)                 (None, 256, 256, 3)  54429315    input_17[0][0]                   
__________________________________________________________________________________________________
model_7 (Model)                 (None, 16, 16, 1)    6968257     input_17[0][0]                   
                                                                 model_8[1][0]                    
Total params: 61,397,572
Trainable params: 54,419,459
Non-trainable params: 6,978,113
__________________________________________________________________________________________________


In [37]:
def generate_real_samples(dataset, n_samples, patch_shape):
    """Draw a random batch of paired images with 'real' labels.

    Args:
        dataset: pair `(trainA, trainB)` of arrays indexed along axis 0;
            the same random indices are applied to both.
        n_samples: number of pairs to draw (indices are drawn with
            `np.random.randint`, so repeats are possible).
        patch_shape: side length of the square label map.

    Returns:
        `[X1, X2], y` where `X1`/`X2` are the selected rows of
        `trainA`/`trainB` and `y` is an array of ones with shape
        `(n_samples, patch_shape, patch_shape, 1)`.
    """
    sources, targets = dataset
    # One shared index vector keeps each source aligned with its target.
    picks = np.random.randint(0, sources.shape[0], n_samples)
    real_labels = np.ones((n_samples, patch_shape, patch_shape, 1))
    return [sources[picks], targets[picks]], real_labels

# generate a batch of images, returns images and targets
def generate_fake_samples(g_model, samples, patch_shape):
    """Run the generator on a batch and pair the result with 'fake' labels.

    Args:
        g_model: object exposing `predict(samples)`.
        samples: batch passed unchanged to `g_model.predict`.
        patch_shape: side length of the square label map.

    Returns:
        `X, y` where `X` is the prediction and `y` is an array of zeros
        with shape `(len(X), patch_shape, patch_shape, 1)`.
    """
    fake_images = g_model.predict(samples)
    # One all-zero patch map per generated image.
    n_fake = len(fake_images)
    fake_labels = np.zeros((n_fake, patch_shape, patch_shape, 1))
    return fake_images, fake_labels


In [38]:
def train(d_model, g_model, gan_model, dataset, n_epochs=100, n_batch=1, n_patch=16):
    """Alternate discriminator and generator updates and print the losses.

    Each step draws a random real batch, generates fakes for the same
    sources, updates the discriminator once on real pairs and once on
    fake pairs, then updates the generator through the composite model.

    Args:
        d_model: discriminator, updated with `train_on_batch`.
        g_model: generator, used to produce the fake targets.
        gan_model: composite model, updated with `train_on_batch`.
        dataset: pair `(trainA, trainB)` of source/target arrays.
        n_epochs: number of passes used to size the step count.
        n_batch: samples drawn per step.
        n_patch: side length of the discriminator's patch label map.
    """
    sources, _ = dataset
    # Total step count = (batches per epoch) * epochs; batches are drawn at
    # random each step, so an "epoch" is only a unit of step count here.
    batches_per_epoch = int(len(sources) / n_batch)
    total_steps = batches_per_epoch * n_epochs

    for step in range(total_steps):
        # Real pairs labelled 1, generated pairs labelled 0.
        (real_src, real_tgt), real_labels = generate_real_samples(dataset, n_batch, n_patch)
        fake_tgt, fake_labels = generate_fake_samples(g_model, real_src, n_patch)

        # Two separate discriminator updates: real batch, then fake batch.
        d_loss_real = d_model.train_on_batch([real_src, real_tgt], real_labels)
        d_loss_fake = d_model.train_on_batch([real_src, fake_tgt], fake_labels)

        # Generator update: target "real" labels plus the true target image.
        # Only the first returned loss value is reported.
        g_loss, _, _ = gan_model.train_on_batch(real_src, [real_labels, real_tgt])

        print('>%d, d1[%.3f] d2[%.3f] g[%.3f]' % (step + 1, d_loss_real, d_loss_fake, g_loss))

In [39]:
# Smoke test of the training loop on random data rather than real images.
# The Fashion-MNIST alternative below is left disabled.
#dataset = (trainX,trainX)

# Random source/target batches matching image_shape (256, 256, 3).
# NOTE(review): np.random.rand yields values in [0, 1) while the generator
# ends in tanh -- real data would presumably be scaled to [-1, 1]; confirm.
x = np.random.rand(1000, 256, 256, 3)
y = np.random.rand(1000, 256, 256, 3)
# Uses train()'s defaults: n_epochs=100, n_batch=1, n_patch=16.
train(d_model, g_model, gan_model, (x,y))

  'Discrepancy between trainable weights and collected trainable'


>1, d1[0.344] d2[0.753] g[55.212]
>2, d1[0.269] d2[0.623] g[53.178]
>3, d1[0.338] d2[0.520] g[51.122]
>4, d1[0.393] d2[0.472] g[49.105]
>5, d1[0.399] d2[0.434] g[47.081]
>6, d1[0.395] d2[0.421] g[45.112]
>7, d1[0.409] d2[0.392] g[42.980]
>8, d1[0.378] d2[0.409] g[41.180]
>9, d1[0.379] d2[0.393] g[39.347]
>10, d1[0.354] d2[0.378] g[37.829]
>11, d1[0.333] d2[0.357] g[36.456]
>12, d1[0.315] d2[0.317] g[35.115]
>13, d1[0.319] d2[0.308] g[34.183]
>14, d1[0.282] d2[0.278] g[33.212]
>15, d1[0.270] d2[0.251] g[32.444]
>16, d1[0.249] d2[0.289] g[31.940]
>17, d1[0.258] d2[0.222] g[31.314]
>18, d1[0.256] d2[0.220] g[30.894]
>19, d1[0.220] d2[0.177] g[30.496]
>20, d1[0.184] d2[0.275] g[30.739]
>21, d1[0.629] d2[0.436] g[29.910]
>22, d1[0.406] d2[0.195] g[28.752]
>23, d1[0.227] d2[0.538] g[29.672]
>24, d1[1.009] d2[0.075] g[27.994]
>25, d1[0.446] d2[0.348] g[27.740]
>26, d1[0.217] d2[0.527] g[28.282]
>27, d1[0.697] d2[0.193] g[27.458]
>28, d1[0.579] d2[0.757] g[27.447]
>29, d1[0.656] d2[0.263] g[27

>234, d1[0.009] d2[0.055] g[29.585]
>235, d1[0.039] d2[0.033] g[29.499]
>236, d1[0.030] d2[0.038] g[29.361]
>237, d1[0.013] d2[0.047] g[29.248]
>238, d1[0.008] d2[0.046] g[29.345]
>239, d1[0.011] d2[0.049] g[29.430]
>240, d1[0.014] d2[0.063] g[29.502]
>241, d1[0.024] d2[0.091] g[29.921]
>242, d1[0.103] d2[0.262] g[31.211]
>243, d1[0.565] d2[0.163] g[29.572]
>244, d1[0.005] d2[0.082] g[30.416]
>245, d1[0.015] d2[0.038] g[30.630]
>246, d1[0.038] d2[0.069] g[30.428]
>247, d1[0.026] d2[0.099] g[31.115]
>248, d1[0.179] d2[0.980] g[32.935]
>249, d1[2.243] d2[0.006] g[31.849]
>250, d1[2.054] d2[0.023] g[29.938]
>251, d1[1.380] d2[0.100] g[28.378]
>252, d1[0.757] d2[0.267] g[27.464]
>253, d1[0.434] d2[0.398] g[27.080]
>254, d1[0.355] d2[0.391] g[26.964]
>255, d1[0.362] d2[0.355] g[26.911]
>256, d1[0.365] d2[0.333] g[26.810]
>257, d1[0.360] d2[0.326] g[26.666]
>258, d1[0.347] d2[0.325] g[26.605]
>259, d1[0.324] d2[0.320] g[26.565]
>260, d1[0.314] d2[0.309] g[26.546]
>261, d1[0.304] d2[0.301] g[

KeyboardInterrupt: 

In [35]:
# Inspect the shape of the random source array created for the training run.
x.shape

(1000, 256, 256, 3)