# In [None]:
import os

import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras import *

def load_train_image():
    """Placeholder for the training-image loader; not implemented yet.

    Currently does nothing and returns ``None``.
    """
    # NOTE(review): this function is handed to ``Dataset.map`` further down,
    # which calls its function with each dataset element, so a real
    # implementation will need to accept that element as a parameter.
    return None

def load_test_image():
    """Placeholder for the test-image loader; not implemented yet.

    Currently does nothing and returns ``None``.
    """
    # NOTE(review): this function is handed to ``Dataset.map`` further down,
    # which calls its function with each dataset element, so a real
    # implementation will need to accept that element as a parameter.
    return None

def normalize(inimg, tgimg):
    """Rescale an input/target image pair from the (0, 255) range to (-1, 1).

    Both images go through the same mapping: divide by 127.5, then subtract 1.

    Returns:
        A tuple ``(normalized_input, normalized_target)``.
    """
    half_range = 127.5  # Half of 255, so that 0 -> -1 and 255 -> 1.

    scaled_input = (inimg / half_range) - 1
    scaled_target = (tgimg / half_range) - 1

    return scaled_input, scaled_target

def random_jitter(inimg, tgimg):
    """Apply the same random crop and random horizontal flip to an image pair.

    The pair is first resized to 286x286, then a random
    ``IMG_HEIGHT`` x ``IMG_WEIGTH`` window is cropped out of both images at
    the same location, and finally both are mirrored left-right with
    probability 0.5.

    Args:
        inimg: Input image (3 channels).
        tgimg: Target image (3 channels).

    Returns:
        A tuple ``(inimg, tgimg)`` with the augmented images.
    """
    # NOTE(review): `resize`, `IMG_HEIGHT` and `IMG_WEIGTH` are not defined in
    # the visible part of this file; they must be provided elsewhere
    # (`IMG_WEIGTH` looks like a misspelling of IMG_WIDTH -- confirm).
    inimg, tgimg = resize(inimg, tgimg, 286, 286)

    # Stack the two images so the very same crop window is applied to both.
    stacked_image = tf.stack([inimg, tgimg], axis=0)

    # The first size dimension is the number of pictures (the two stacked
    # images); the last one is the number of channels, here 3 (RGB).
    cropped_image = tf.image.random_crop(stacked_image, size=[2, IMG_HEIGHT, IMG_WEIGTH, 3])

    inimg, tgimg = cropped_image[0], cropped_image[1]

    if tf.random.uniform(()) > 0.5:  # Randomly flip both images together.
        inimg = tf.image.flip_left_right(inimg)
        tgimg = tf.image.flip_left_right(tgimg)

    return inimg, tgimg

# The training-set pictures have to be specified!
# NOTE(review): `training` and `test` are not defined in the visible part of
# this file; they must be provided before this point.
train_dataset = tf.data.Dataset.from_tensor_slices(training)
# map() applies the function passed as argument to every element of the dataset.
# num_parallel_calls sets how many elements are processed in parallel; with
# tf.data.experimental.AUTOTUNE the value is chosen automatically.
train_dataset = train_dataset.map(load_train_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
NUM_OF_BATCHES = 1  # Batch size (elements per batch). The pix2pix paper uses 1!
train_dataset = train_dataset.batch(NUM_OF_BATCHES)  # Batches are needed for mini-batch optimization.

test_dataset = tf.data.Dataset.from_tensor_slices(test)
test_dataset = test_dataset.map(load_test_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
# Batch the test set as well: the generator and generate_images() (which
# indexes `test_input[0]`) expect a leading batch dimension.
test_dataset = test_dataset.batch(NUM_OF_BATCHES)
   
def downsample(filters, apply_batchnorm=True):  # Downsampler ~ encoder
    """Build one encoder block: strided convolution -> (batch norm) -> LeakyReLU.

    Args:
        filters: Number of convolution filters (output channels).
        apply_batchnorm: Whether to add a batch-normalization layer after the
            convolution.

    Returns:
        A ``Sequential`` model that halves the spatial resolution of its
        input (stride 2, "same" padding).
    """
    result = Sequential()

    initializer = tf.random_normal_initializer(0, 0.02)

    # Convolutional layer. The bias is only used when batch normalization is
    # not applied, since batch normalization already adds its own offset.
    result.add(Conv2D(filters, kernel_size=4, strides=2, padding="same", kernel_initializer=initializer, use_bias=not apply_batchnorm))

    # Batch normalization
    if apply_batchnorm:
        result.add(BatchNormalization())

    # Activation layer: an *instance* of the layer must be added, not the class.
    result.add(LeakyReLU())

    return result
    
def upsample(filters, apply_dropout=False):  # Upsampler ~ decoder
    """Build one decoder block: transposed conv -> batch norm -> (dropout) -> ReLU.

    Args:
        filters: Number of transposed-convolution filters (output channels).
        apply_dropout: Whether to add a ``Dropout(0.5)`` layer after batch
            normalization.

    Returns:
        A ``Sequential`` model that doubles the spatial resolution of its
        input (stride 2, "same" padding).
    """
    result = Sequential()

    initializer = tf.random_normal_initializer(0, 0.02)

    # Transposed convolution. No bias: batch normalization is always applied
    # right after and already adds its own offset.
    result.add(Conv2DTranspose(filters, kernel_size=4, strides=2, padding="same", kernel_initializer=initializer, use_bias=False))

    # Batch normalization
    result.add(BatchNormalization())

    # Dropout layer (regularization technique)
    if apply_dropout:
        result.add(Dropout(0.5))

    # Activation layer: ReLU as specified in the paper. An *instance* of the
    # layer must be added, not the class.
    result.add(ReLU())

    return result


def Generator():
    """Build the U-Net style generator.

    The input goes through eight downsampling blocks; seven upsampling blocks
    then bring it back up, each one concatenated with the encoder output of
    matching resolution (skip connection). A final transposed convolution with
    ``tanh`` activation produces the 3-channel output image.

    Returns:
        A Keras ``Model`` mapping a 3-channel image batch to a 3-channel
        image batch.
    """
    inputs = tf.keras.layers.Input(shape=[None, None, 3])

    # Shape comments assume a 256x256 input; bs = batch size.
    down_stack = [
        downsample(64, apply_batchnorm=False),  # (bs, 128, 128, 64) resolution halved
        downsample(128),  # (bs, 64, 64, 128) channels grow with the number of filters
        downsample(256),  # (bs, 32, 32, 256)
        downsample(512),
        downsample(512),
        downsample(512),
        downsample(512),
        downsample(512),  # (bs, 1, 1, 512)
    ]

    up_stack = [
        upsample(512, apply_dropout=True),
        upsample(512, apply_dropout=True),
        upsample(512, apply_dropout=True),
        upsample(512),
        upsample(256),
        upsample(128),
        upsample(64),
    ]

    initializer = tf.random_normal_initializer(0, 0.02)
    # `filters` defines the channels of the final image, therefore filters=3.
    last_layer = Conv2DTranspose(filters=3,
                                 kernel_size=4,
                                 strides=2,
                                 padding="same",
                                 kernel_initializer=initializer,
                                 activation="tanh"  # Images are normalized to (-1, 1)
                                 )

    x = inputs
    skip_connections = []  # Encoder outputs, collected for the skip connections
    concat = Concatenate()
    for down in down_stack:
        x = down(x)
        skip_connections.append(x)

    # The bottleneck output (last element) is not a skip connection; the rest
    # are consumed from the deepest to the shallowest.
    skip_connections = reversed(skip_connections[:-1])

    for up, skip in zip(up_stack, skip_connections):
        x = up(x)
        x = concat([x, skip])

    # The output layer is applied once, after the whole decoder has run.
    last_image = last_layer(x)

    return Model(inputs=inputs, outputs=last_image)  # Construct the model object
        
        
def Discriminator():
    """Build the patch-based discriminator (PatchGAN).

    Two 3-channel images are concatenated and passed through four
    downsampling blocks (64, 128, 256, 512 filters), followed by a stride-1
    convolution with a single filter. The result is a one-channel map with
    one score per image area; no activation is applied to it.

    Returns:
        A Keras ``Model`` taking ``[input_img, gener_img]`` and returning the
        one-channel score map.
    """
    first_img = Input(shape=[None, None, 3], name="input_img")
    second_img = Input(shape=[None, None, 3], name="gener_img")

    # Join both images so they are judged together.
    features = concatenate([first_img, second_img])

    kernel_init = tf.random_normal_initializer(0, 0.02)

    # The first block skips batch normalization; the remaining ones keep it.
    features = downsample(64, apply_batchnorm=False)(features)
    for n_filters in (128, 256, 512):
        features = downsample(n_filters)(features)

    # A single filter: each position of the output map scores one area of the
    # image, so only one channel is needed (PatchGAN!).
    patch_scores = tf.keras.layers.Conv2D(
        filters=1,
        kernel_size=4,
        strides=1,
        kernel_initializer=kernel_init,
        padding="same",
    )(features)

    return tf.keras.Model(inputs=[first_img, second_img], outputs=patch_scores)
    


loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True) #from_logits=True: predictions are treated as raw logits (no sigmoid applied beforehand); the loss handles the sigmoid itself

def discriminator_loss(disc_real_output, disc_generated_output):
    """Compute the discriminator loss.

    Args:
        disc_real_output: Discriminator output for the real (target) image.
        disc_generated_output: Discriminator output for the generated image.

    Returns:
        The sum of the cross-entropy of the real output against an all-ones
        target and of the generated output against an all-zeros target.
    """
    # An all-ones target means "real". ones_like builds a tensor with the same
    # shape as the output; tf.ones would interpret its argument as a shape.
    real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)

    # An all-zeros target means "fake".
    generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)

    total_disc_loss = real_loss + generated_loss

    return total_disc_loss
    
LAMBDA = 100  # Weight of the L1 term in the generator loss


def generator_loss(disc_generated_output, gen_output, target):
    """Compute the generator loss: adversarial term plus weighted L1 term.

    Args:
        disc_generated_output: Discriminator output for the generated image.
        gen_output: Image produced by the generator.
        target: Ground-truth image.

    Returns:
        ``gan_loss + LAMBDA * l1_loss``, where ``gan_loss`` is the
        cross-entropy of the discriminator output against an all-ones target
        and ``l1_loss`` is the mean absolute error between ``target`` and
        ``gen_output``.
    """
    # The generator wants the discriminator to label its output as real (ones).
    gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)

    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))

    total_gen_loss = gan_loss + (LAMBDA * l1_loss)

    return total_gen_loss
    
    
# One Adam optimizer per network (learning rate 2e-4, beta_1 = 0.5).
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

# Instantiate the two networks; the checkpoint and the training functions
# refer to these module-level objects.
generator = Generator()
discriminator = Discriminator()

# NOTE(review): CKPATH is not defined in the visible part of this file; it
# must be set to the checkpoint directory before this point.
checkpoint_prefix = os.path.join(CKPATH, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

def generate_images(model, test_input, tar, save_filename=False, display_imgs=True):
    """Run the model on a test batch and optionally save/plot the result.

    Args:
        model: Model used to produce the prediction.
        test_input: Batch of input images; only element 0 is displayed.
        tar: Batch of ground-truth images; only element 0 is displayed.
        save_filename: If truthy, base name (without extension) under which
            the first predicted image is saved inside ``PATH_SAVE``.
        display_imgs: If true, plot input, ground truth and prediction side
            by side.
    """
    # training=True is kept from the original code.
    prediction = model(test_input, training=True)

    if save_filename:
        # NOTE(review): PATH_SAVE is not defined in the visible part of this
        # file; it must point to an existing output directory.
        save_path = os.path.join(PATH_SAVE, str(save_filename) + ".jpg")
        tf.keras.preprocessing.image.save_img(save_path, prediction[0])

    if display_imgs:
        # NOTE(review): `plt` (presumably matplotlib.pyplot) is not imported
        # in the visible part of this file.
        plt.figure(figsize=(10, 10))

        display_list = [test_input[0], tar[0], prediction[0]]
        title = ["Input image", "Ground Truth", "Predicted Image"]

        for i in range(3):
            plt.subplot(1, 3, i + 1)
            plt.title(title[i])
            # Map pixel values from (-1, 1) back to (0, 1) for plotting.
            plt.imshow(display_list[i] * 0.5 + 0.5)
            plt.axis("off")

        plt.show()
    
    
def train_step(input_image, target):
    """Run one optimization step of both generator and discriminator.

    Args:
        input_image: Batch of input images fed to the generator.
        target: Batch of ground-truth images.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as discr_tape:

        output_image = generator(input_image, training=True)

        # Discriminator scores for the generated image and for the real one,
        # each paired with the input image.
        output_gen_discr = discriminator([output_image, input_image], training=True)

        output_trg_discr = discriminator([target, input_image], training=True)

        discr_loss = discriminator_loss(output_trg_discr, output_gen_discr)

        gen_loss = generator_loss(output_gen_discr, output_image, target)

    # Gradients are computed after leaving the tape context so that the
    # gradient computation itself is not recorded.
    generator_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)

    discriminator_grads = discr_tape.gradient(discr_loss, discriminator.trainable_variables)

    # Apply the updates; without this step the networks would never learn.
    generator_optimizer.apply_gradients(zip(generator_grads, generator.trainable_variables))

    discriminator_optimizer.apply_gradients(zip(discriminator_grads, discriminator.trainable_variables))
    

def train(dataset, epochs):
    """Train the networks, previewing test images and checkpointing as it goes.

    Args:
        dataset: Iterable yielding ``(input_image, target)`` batches.
        epochs: Number of passes over ``dataset``.

    After each epoch, five elements of ``test_dataset`` are run through the
    generator and saved/displayed; a checkpoint is written every 25 epochs.
    """
    for epoch in range(epochs):

        for input_image, target in dataset:
            train_step(input_image, target)

        # `imgi` numbers the test images so each one gets its own filename
        # instead of all of them overwriting the same file.
        for imgi, (inp, tar) in enumerate(test_dataset.take(5)):
            generate_images(generator, inp, tar, str(imgi) + "_ " + str(epoch), display_imgs=True)

        if (epoch + 1) % 25 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)
            
    
    
    
    