In [37]:
from numpy import zeros
from numpy import ones
from numpy.random import randint
from keras.optimizers import Adam
from keras.initializers import RandomNormal
from keras.models import Model
#from keras.models import Input
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import Activation
from keras.layers import Concatenate
from keras.layers import Dropout
from keras.layers import BatchNormalization
from matplotlib import pyplot as plt
from tensorflow.keras.utils import plot_model
from keras.layers import Input

In [38]:
def define_discriminator(image_shape):
    init=RandomNormal(stddev=0.02)  #weight initialization
    in_src_image=Input(shape=image_shape)
    in_target_image=Input(shape=image_shape)
    merged=Concatenate()([in_src_image,in_target_image])

    d=Conv2D(64,(4,4), strides=(2,2), padding="same", kernel_initializer=init)(merged)
    d=LeakyReLU(alpha=0.2)(d)

    d=Conv2D(128,(4,4),strides=(2,2),padding="same",kernel_initializer=init)(d)
    d=BatchNormalization()(d)
    d=LeakyReLU(alpha=0.2)(d)

    d=Conv2D(256,(4,4),strides=(2,2),padding="same",kernel_initializer=init)(d)
    d=BatchNormalization()(d)
    d=LeakyReLU(alpha=0.2)(d)

    d=Conv2D(512,(4,4),strides=(2,2),padding="same",kernel_initializer=init)(d)
    d=BatchNormalization()(d)
    d=LeakyReLU(alpha=0.2)(d)

    d=Conv2D(512,(4,4),padding="same",kernel_initializer=init)(d) #stride =(1,1)
    d=BatchNormalization()(d)
    d=LeakyReLU(alpha=0.2)(d)

    #patch out
    d=Conv2D(1,(4,4),padding='same',kernel_initializer=init)(d)
    patch_out=Activation('sigmoid')(d)

    #define model
    model=Model([in_src_image,in_target_image],patch_out)

    #Compile model
    opt=Adam(learning_rate=0.0002,beta_1=0.5)
    model.compile(loss='binary_crossentropy',optimizer=opt,loss_weights=[0.5])
    return model
    

In [39]:
test_discr=define_discriminator((64,64,3))
print(test_discr.summary())



None


In [40]:
def define_encoder_block(layer_in, n_filters, batchnorm=True):
    #weight initialization
    init=RandomNormal(stddev=0.02)
    #add downsampling layer
    g=Conv2D(n_filters,(4,4),strides=(2,2),padding="same",kernel_initializer=init)(layer_in)
    #conditionally and batch normalization
    if batchnorm:
        g=BatchNormalization()(g,training=True)
    g=LeakyReLU(alpha=0.2)(g)
    return g
    

In [41]:
def decoder_block(layer_in,skip_in,n_filters,dropout=True):
    #weight initialization
    init=RandomNormal(stddev=0.02)
    g=Conv2DTranspose(n_filters,(4,4),strides=(2,2),padding='same',kernel_initializer=init)(layer_in)
    g=BatchNormalization()(g,training=True)
    if dropout:
        g=Dropout(0.5)(g,training=True)
    #merge with skip connection
    g=Concatenate()([g,skip_in])
    g=Activation('relu')(g)
    return g

In [42]:
def define_generator(image_shape=(256,256,3)):
    init=RandomNormal(stddev=0.02)
    in_image=Input(shape=image_shape)

    #encoder model
    e1=define_encoder_block(in_image,64,batchnorm=False)
    e2=define_encoder_block(e1,128)
    e3=define_encoder_block(e2,256)
    e4=define_encoder_block(e3,512)
    e5=define_encoder_block(e4,512)
    e6=define_encoder_block(e5,512)
    e7=define_encoder_block(e6,512)

    #bottleneck
    b=Conv2D(512,(4,4),strides=(2,2),padding='same',kernel_initializer=init)(e7)
    b=Activation('relu')(b)

    #decoder model
    d1=decoder_block(b,e7,512)
    d2=decoder_block(d1,e6,512)
    d3=decoder_block(d2,e5,512)
    d4=decoder_block(d3,e4,512,dropout=False)
    d5=decoder_block(d4,e3,256,dropout=False)
    d6=decoder_block(d5,e2,128,dropout=False)
    d7=decoder_block(d6,e1,64,dropout=False)

    #output
    g=Conv2DTranspose(image_shape[2],(4,4),strides=(2,2),padding='same',kernel_initializer=init)(d7)
    out_image=Activation('tanh')(g)
    model=Model(in_image,out_image)
    return model

In [43]:
gen_model=define_generator()
#plot_model(gen_model,to_file='gen_model.png',show_shapes=True)
print(gen_model.summary())

None


In [44]:
def define_gan(g_model,d_model,image_shape):
    #make weights in the discriminator not trainable
    for layer in d_model.layers:
        if not isinstance(layer,BatchNormalization):
            layer.trainable=False
    
    in_src=Input(shape=image_shape)
    gen_out=g_model(in_src)
    dis_out=d_model([in_src,gen_out])
    
    model=Model(in_src,[dis_out,gen_out])
    opt=Adam(learning_rate=0.0002,beta_1=0.5)
    
    model.compile(loss=['binary_crossentropy','mae'],optimizer=opt,loss_weights=[1,100])
    return model

In [45]:
def generate_real_samples(dataset,n_samples,patch_shape):
    trainA,trainB=dataset
    ix=randint(0,trainA.shape[0],n_samples)
    X1,X2=trainA[ix],trainB[ix]
    y=ones((n_samples,patch_shape,patch_shape,1))
    return [X1,X2],y

In [46]:
def generate_fake_smples(g_model,samples,patch_shape):
    X=g_model.predict(samples)
    y=zeros((len(X),patch_shape,patch_shape,1))
    return X,y

In [47]:
def summarize_performance(step,g_model,dataset,n_samples=3):
    [X_realA,X_realB],_=generate_real_samples(dataset,n_samples,1)
    X_fakeB,_=generate_fake_smples(g_model,X_realA,1)
    X_realA=(X_realA+1)/2.0
    X_realB=(X_realB+1)/2.0
    X_fakeB=(X_fakeB+1)/2.0
    
    #plot real images
    for i in range(n_samples):
        plt.subplot(3,n_samples,1+i)
        plt.axis('off')
        plt.imshow(X_realA[i])
    
    #plot generated target images
    for i in range(n_samples):
        plt.subplot(3,n_samples,1+n_samples+i)
        plt.axis('off')
        plt.imshow(X_fakeB[i])
        
    #plot real target images
    for i in range(n_samples):
        plt.subplot(3,n_samples*2,1+i)
        plt.axis('off')
        plt.imshow(X_realB[i])
        
    #save generated model
    file_name='pix2pix_model_%06d.h5' %(step+1)
    g_model.save(file_name)

In [48]:
def train(d_model,g_model,gan_model,datset,n_epochs=100,n_batch=1):
    n_patch=d_model.output_shape[1]
    trainA,trainB=dataset
    bat_per_epo=int(len(trainA)/n_batch)
    n_steps=bat_per_epo*n_epochs
    for i in range(n_steps):
        [X_realA,X_realB],y_real=generate_real_samples(dataset,n_batch,n_patch)
        X_fakeB,y_fake=generate_fake_smples(g_model,X_realA,n_patch)
        d_loss1=d_model.train_on_batch([X_realA,X_realB],y_real)
        d_loss2=d_model.train_on_batch([X_realA,X_fakeB],y_fake)
        g_loss,_,_=gan_model.train_on_batch(X_realA,[y_real,X_realB])
        print('>%d, d1[%.3f] d2[%.3f] g[%.3f]' % (i+1, d_loss1, d_loss2, g_loss))
        if (i+1) % (bat_per_epo * 10) == 0:
            summarize_performance(i, g_model, dataset)

In [49]:
from os import listdir
import numpy
from numpy import asarray,load , vstack ,savez_compressed
from keras.preprocessing.image import img_to_array,load_img
from matplotlib import pyplot
import cv2 as cv
import os

In [50]:
def load_images(path):
    img_files=os.listdir(path)
    print(len(img_files))
    images=[]
    src_list=[]
    tar_list=[]
    for image_file in img_files:
        img_path=os.path.join(path,image_file)
        img=pyplot.imread(img_path)
        img=cv.resize(img,(512,256))
        src_list.append(img[:,:256])
        tar_list.append(img[:,256:])
    return [asarray(src_list),asarray(tar_list)]

In [51]:
path='maps//train//'

In [52]:
[src_images,tar_images]=load_images(path)

1096


In [53]:
src_images.shape

(1096, 256, 256, 3)

In [54]:
tar_images.shape

(1096, 256, 256, 3)

In [55]:
#pyplot.imshow(src_images[0])

In [56]:
#pyplot.imshow(tar_images[0])

In [57]:
image_shape=src_images[0].shape
image_shape

(256, 256, 3)

In [59]:
d_model=define_discriminator(image_shape)
g_model=define_generator(image_shape)
gan_model=define_gan(g_model,d_model,image_shape)



In [60]:
#data=[src_images,tar_images]

In [61]:
def preprocess_data(data):
    X1,X2=data[0],data[1]
    X1=(X1-127.5)/127.5
    X2=(X2-127.5)/127.5
    return [X1,X2]


In [63]:
dataset=preprocess_data([src_images,tar_images])

In [62]:
from datetime import datetime
start1=datetime.now()

In [None]:
train(d_model,g_model,gan_model,dataset,n_epochs=10,n_batch=1)

In [64]:
stop1=datetime.now()

In [65]:
execution_time=start1-stop1
execution_time

datetime.timedelta(days=-1, seconds=86256, microseconds=335899)