<a href="https://colab.research.google.com/github/z-arabi/SRU-deeplearning-workshop/blob/master/24_gan_pix2pix_tensorflow_eager.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# pix2pix

Code Source

[https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb)


This notebook demonstrates image-to-image translation using conditional GANs, as described in [Image-to-Image Translation with Conditional Adversarial Networks](https://arxiv.org/abs/1611.07004). Using this technique we can colorize black-and-white photos, convert Google Maps tiles to Google Earth imagery, etc. Here, we convert building facade label maps to photos of real buildings. We use [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager) to achieve this.

In this example, we will use the [CMP Facade Database](http://cmp.felk.cvut.cz/~tylecr1/facade/), helpfully provided by the [Center for Machine Perception](http://cmp.felk.cvut.cz/) at the [Czech Technical University in Prague](https://www.cvut.cz/). To keep our example short, we will use a preprocessed [copy](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/) of this dataset, created by the authors of the [paper](https://arxiv.org/abs/1611.07004) above.

Each epoch takes around 58 seconds on a single P100 GPU.

Below is the output generated after training the model for 200 epochs.


![sample output_1](https://www.tensorflow.org/images/gan/pix2pix_1.png)
![sample output_2](https://www.tensorflow.org/images/gan/pix2pix_2.png)

## Import TensorFlow and enable eager execution

In [None]:
!pip install --quiet tensorflow==2.8

In [None]:
# eager execution is enabled by default in TensorFlow >= 2.0
import tensorflow as tf
print(tf.__version__)

import os
import time
import numpy as np
import matplotlib.pyplot as plt
import PIL
from IPython.display import clear_output

## Load the dataset

You can download this dataset and similar datasets from [here](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets). As mentioned in the [paper](https://arxiv.org/abs/1611.07004), we apply random jittering and mirroring to the training dataset. Both are data augmentation techniques.
* In random jittering, the image is first resized to a larger size (`286 x 286`) and then randomly cropped back to its original size (`256 x 256`).
* In random mirroring, the image is randomly flipped horizontally, i.e. left to right.
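
The key point of both augmentations is that the input and the ground truth must receive the same crop window and the same flip decision, otherwise the pair is no longer pixel-aligned. The following is a minimal NumPy sketch of that idea, not the notebook's actual TensorFlow pipeline. The helper name `paired_random_crop_and_flip` and the random arrays standing in for already-enlarged `286 x 286` images are assumptions made for illustration.

```python
import numpy as np

def paired_random_crop_and_flip(input_image, real_image, crop_size, rng):
    """Apply the same random crop and the same horizontal flip to both images."""
    height, width = input_image.shape[:2]
    # One random window, shared by both images
    top = rng.integers(0, height - crop_size + 1)
    left = rng.integers(0, width - crop_size + 1)
    window = (slice(top, top + crop_size), slice(left, left + crop_size))
    input_crop, real_crop = input_image[window], real_image[window]
    # One coin flip, shared by both images (reverse the width axis)
    if rng.random() > 0.5:
        input_crop, real_crop = input_crop[:, ::-1], real_crop[:, ::-1]
    return input_crop, real_crop

rng = np.random.default_rng(0)
enlarged_input = rng.random((286, 286, 3))   # hypothetical stand-in for a resized input
enlarged_real = enlarged_input + 1.0         # pixel-aligned stand-in for the ground truth
input_crop, real_crop = paired_random_crop_and_flip(enlarged_input, enlarged_real, 256, rng)
print(input_crop.shape, real_crop.shape)
```

Because the stand-in ground truth is the input plus one, the two crops still differ by exactly one everywhere, which shows that they stayed aligned.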

In [None]:
path_to_zip = tf.keras.utils.get_file('facades.tar.gz',
                                      cache_subdir=os.path.abspath('.'),
                                      origin='http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/facades.tar.gz',
                                      extract=True)

print(path_to_zip, os.path.dirname(path_to_zip))

PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')

In [None]:
train_path = os.path.join(PATH,"train")
train_images = os.listdir(train_path)
train_images = [os.path.join(train_path, i) for i in train_images]

img_sample = plt.imread(train_images[0])
plt.imshow(img_sample)

In [None]:
img_sample.min(), img_sample.max()

In [None]:
print(img_sample.shape)

fig, axes = plt.subplots(1,2)

axes[0].imshow(img_sample[:,:256,:])
axes[0].set_title("the ground truth")
axes[1].imshow(img_sample[:,256:,:])
axes[1].set_title("the input")

In [None]:
print(len(train_images))

test_path = os.path.join(PATH,"test")
test_images = os.listdir(test_path)
print(len(test_images))

val_path = os.path.join(PATH,"val")
val_images = os.listdir(val_path)
print(len(val_images))

In [None]:
BUFFER_SIZE = 400
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256

In [None]:
# Several ways to read and resize an image: OpenCV, PIL, scikit-image, TensorFlow, torchvision

import cv2
# cv2 reads images in BGR channel order
image = cv2.imread(train_images[0])
print("the cv2 reading", type(image))
plt.figure()
plt.imshow(image)
plt.figure()
plt.imshow(image[:,:,-1])
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.figure()
plt.imshow(image_rgb)
resized_image = cv2.resize(image, (200, 200))
plt.figure()
plt.imshow(resized_image)


from PIL import Image
image = Image.open(train_images[0])
print("the PIL reading", type(image))
plt.figure()
plt.imshow(image)
resized_image = image.resize((200, 200))


from skimage.transform import resize
import imageio
image = imageio.imread(train_images[0])
print("the imageio reading", type(image))
resized_image = resize(image, (200, 200))


import tensorflow as tf
image = tf.io.read_file(train_images[0])
print("the tf reading", type(image))
image = tf.image.decode_jpeg(image)
print("the tf reading", type(image))
plt.figure()
plt.imshow(image)
# tf.image.resize returns float32 values, still in the range 0 to 255
resized_image = tf.image.resize(image, [200, 200])
plt.figure()
plt.imshow(resized_image)
# plt.imshow expects ints in [0, 255] or floats in [0, 1], so cast back to uint8
resized_image = tf.cast(resized_image, tf.uint8)
plt.figure()
plt.imshow(resized_image)


from PIL import Image
from torchvision.transforms import Resize, ToTensor, ToPILImage
image = Image.open(train_images[0])
to_tensor = ToTensor() # convert to a PyTorch tensor
image = to_tensor(image)
print("the pytorch reading", type(image))
resize = Resize((200, 200))
image = resize(image)
to_pil_image = ToPILImage()
image = to_pil_image(image)

In [None]:
def load_image(image_file, is_train):
  image = tf.io.read_file(image_file)
  image = tf.image.decode_jpeg(image)

  w = tf.shape(image)[1]

  w = w // 2
  real_image = image[:, :w, :]
  input_image = image[:, w:, :]

  input_image = tf.cast(input_image, tf.float32)
  real_image = tf.cast(real_image, tf.float32)

  if is_train:
    # random jittering

    # resize both the input and the real image to 286 x 286 x 3
    # nearest-neighbor interpolation copies existing pixel values instead of blending them
    input_image = tf.image.resize(input_image, [286, 286],
                                        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    real_image = tf.image.resize(real_image, [286, 286],
                                        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    # randomly crop to 256 x 256 x 3
    # tf.stack joins the two images along a new axis (axis=0), giving a tensor
    # of shape [2, height, width, channels]; cropping the stacked tensor applies
    # the same crop window to the input and the real image
    stacked_image = tf.stack([input_image, real_image], axis=0)
    cropped_image = tf.image.random_crop(stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])
    input_image, real_image = cropped_image[0], cropped_image[1]

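    # random mirroring: the coin flip is drawn with a TensorFlow op so that it is
    # re-sampled for every image (np.random would run only once, when tf.data
    # traces this function, and fix the same decision for the whole dataset)
    flip = tf.random.uniform(()) > 0.5
    input_image = tf.cond(flip, lambda: tf.image.flip_left_right(input_image), lambda: input_image)
    real_image = tf.cond(flip, lambda: tf.image.flip_left_right(real_image), lambda: real_image)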
  else:
    input_image = tf.image.resize(input_image, size=[IMG_HEIGHT, IMG_WIDTH],
                                         method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    real_image = tf.image.resize(real_image, size=[IMG_HEIGHT, IMG_WIDTH],
                                        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  # normalizing the images to [-1, 1]
  input_image = (input_image / 127.5) - 1
  real_image = (real_image / 127.5) - 1

  return input_image, real_image

## Use tf.data to create batches, map(do preprocessing) and shuffle the dataset

In [None]:
train_dataset = tf.data.Dataset.list_files(PATH+'train/*.jpg')
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.map(lambda x: load_image(x, True))
# each batch will only contain one image
train_dataset = train_dataset.batch(1)

In [None]:
dataset_iterator = iter(train_dataset)
first_batch = next(dataset_iterator)
print("First batch shape:", first_batch[0].shape, first_batch[1].shape)

In [None]:
c = 0
for batch in train_dataset:
    c += 1
print("the whole data size: ", c)

In [None]:
test_dataset = tf.data.Dataset.list_files(PATH+'test/*.jpg')
test_dataset = test_dataset.map(lambda x: load_image(x, False))
test_dataset = test_dataset.batch(1)

## Write the generator and discriminator models

* **Generator**
  * The architecture of the generator is a modified U-Net.
  * Each block in the encoder is (Conv -> Batchnorm -> Leaky ReLU)
  * Each block in the decoder is (Transposed Conv -> Batchnorm -> Dropout(applied to the first 3 blocks) -> ReLU)
  * There are skip connections between the encoder and decoder (as in U-Net). [Skip connections are used to pass information directly from the encoder to the decoder to assist in better reconstruction.]
  
* **Discriminator**
  * The Discriminator is a PatchGAN. [This is a special type of discriminator that classifies whether each patch of an image is real or fake. It does not classify the image as a whole; it operates on patches of the image.]
  * Each block in the discriminator is (Conv -> BatchNorm -> Leaky ReLU)
  * The shape of the output after the last layer is (batch_size, 30, 30, 1) [one logit per patch]
  * Each cell of the 30x30 output classifies a 70x70 portion of the input image (such an architecture is called a PatchGAN).   
  The output forms a 30x30 grid, and each cell corresponds to a 70x70 pixel region of the 256x256 input image. These regions overlap: adjacent patches share most of their pixels, so the discriminator does not need a separate output cell for every disjoint 70x70 patch.  
  Each of the 900 cells casts a "vote" based on its own 70x70 patch, and because the patches overlap, together they give a more nuanced judgment of the whole 256x256 image. [Roughly like a 70x70 filter sliding with a stride of 8, the product of the three stride-2 convolutions.]
  * Discriminator receives 2 inputs.
    * Input image and the target image, which it should classify as real.
    * Input image and the generated image (output of generator), which it should classify as fake.
    * We concatenate these 2 inputs together in the code (`tf.concat([inp, tar], axis=-1)`)

* The shape of the input as it travels through the generator and the discriminator is given in the comments in the code.

To learn more about the architecture and the hyperparameters, you can refer to the [paper](https://arxiv.org/abs/1611.07004).
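
The `(batch_size, 30, 30, 1)` output shape can be checked with plain arithmetic. The sketch below assumes the Keras conventions used in this notebook: `padding='same'` gives `ceil(size / stride)`, the default `padding='valid'` gives `(size - kernel) // stride + 1`, and `ZeroPadding2D()` adds one pixel on each side. The helper `conv_output_size` is a hypothetical name introduced only for this illustration.

```python
import math

def conv_output_size(size, kernel, stride, padding):
    """Spatial output size of a convolution along one dimension."""
    if padding == "same":
        return math.ceil(size / stride)
    return (size - kernel) // stride + 1  # "valid": no implicit padding

size = 256
trace = [size]
for _ in range(3):
    # three downsampling blocks: 4x4 kernel, stride 2, padding='same'
    size = conv_output_size(size, kernel=4, stride=2, padding="same")
    trace.append(size)
for _ in range(2):
    # ZeroPadding2D adds 1 pixel per side, then a 4x4 'valid' conv with stride 1
    size += 2
    trace.append(size)
    size = conv_output_size(size, kernel=4, stride=1, padding="valid")
    trace.append(size)

print(trace)  # [256, 128, 64, 32, 34, 31, 33, 30]
```

The intermediate values 34, 31 and 33 match the shape comments in the discriminator code.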
    

In [None]:
OUTPUT_CHANNELS = 3

In [None]:
class Downsample(tf.keras.Model):

  def __init__(self, filters, size, apply_batchnorm=True):
    super(Downsample, self).__init__()
    self.apply_batchnorm = apply_batchnorm
    # initialize the weights from a normal distribution with mean 0 and stddev 0.02
    initializer = tf.random_normal_initializer(0., 0.02)

    self.conv1 = tf.keras.layers.Conv2D(filters,
                                        (size, size),
                                        strides=2,
                                        padding='same',
                                        kernel_initializer=initializer,
                                        use_bias=False)
    if self.apply_batchnorm:
        self.batchnorm = tf.keras.layers.BatchNormalization()

  def call(self, x, training):
    x = self.conv1(x)
    if self.apply_batchnorm:
        # The `training` argument tells the layer whether it runs in training or
        # inference mode. This matters for BatchNormalization and Dropout, which
        # behave differently in the two modes.
        # BatchNormalization: during training it normalizes with the mean and
        # standard deviation of the current batch; during inference it uses the
        # running averages accumulated during training.
        # Dropout: during training it randomly sets a fraction of its inputs to 0,
        # which helps to prevent overfitting; during inference it does nothing.
        x = self.batchnorm(x, training=training)

    # When encoding or downsampling information, it might be beneficial to preserve as much information as possible,
    # including the sign of the data.
    # Leaky ReLU allows a small, non-zero gradient when the input is less than zero.
    # This can help mitigate the "dying ReLU problem," where neurons can sometimes get stuck during training and always output zero.
    x = tf.nn.leaky_relu(x)
    return x


class Upsample(tf.keras.Model):

  def __init__(self, filters, size, apply_dropout=False):
    super(Upsample, self).__init__()
    self.apply_dropout = apply_dropout
    initializer = tf.random_normal_initializer(0., 0.02)

    self.up_conv = tf.keras.layers.Conv2DTranspose(filters,
                                                   (size, size),
                                                   strides=2,
                                                   padding='same',
                                                   kernel_initializer=initializer,
                                                   use_bias=False)
    self.batchnorm = tf.keras.layers.BatchNormalization()
    if self.apply_dropout:
        self.dropout = tf.keras.layers.Dropout(0.5)

  def call(self, x1, x2, training):
    x = self.up_conv(x1)
    x = self.batchnorm(x, training=training)
    if self.apply_dropout:
        x = self.dropout(x, training=training)
    # ReLU activation leads to sparse representations. Sparsity is often beneficial
    # because it can make the network easier to optimize and can lead to a more expressive model.
    # In many architectures like U-Net, ReLU is commonly used in the decoder (upsampling) part.
    x = tf.nn.relu(x)
    # skip connection: concatenate the matching encoder feature map so that
    # fine detail from the encoder is not lost
    x = tf.concat([x, x2], axis=-1)
    return x


class Generator(tf.keras.Model):

  def __init__(self):
    super(Generator, self).__init__()
    initializer = tf.random_normal_initializer(0., 0.02)

    # the number of filters (channel depth) grows as the spatial size shrinks, up to 512
    self.down1 = Downsample(64, 4, apply_batchnorm=False)
    self.down2 = Downsample(128, 4)
    self.down3 = Downsample(256, 4)
    self.down4 = Downsample(512, 4)
    self.down5 = Downsample(512, 4)
    self.down6 = Downsample(512, 4)
    self.down7 = Downsample(512, 4)
    self.down8 = Downsample(512, 4)

    self.up1 = Upsample(512, 4, apply_dropout=True)
    self.up2 = Upsample(512, 4, apply_dropout=True)
    self.up3 = Upsample(512, 4, apply_dropout=True)
    self.up4 = Upsample(512, 4)
    self.up5 = Upsample(256, 4)
    self.up6 = Upsample(128, 4)
    self.up7 = Upsample(64, 4)

    # OUTPUT_CHANNELS = 3 (RGB)
    self.last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS,
                                                (4, 4),
                                                strides=2,
                                                padding='same',
                                                kernel_initializer=initializer)

  def call(self, x, training):
    # x shape == (bs, 256, 256, 3)
    x1 = self.down1(x, training=training) # (bs, 128, 128, 64)
    x2 = self.down2(x1, training=training) # (bs, 64, 64, 128)
    x3 = self.down3(x2, training=training) # (bs, 32, 32, 256)
    x4 = self.down4(x3, training=training) # (bs, 16, 16, 512)
    x5 = self.down5(x4, training=training) # (bs, 8, 8, 512)
    x6 = self.down6(x5, training=training) # (bs, 4, 4, 512)
    x7 = self.down7(x6, training=training) # (bs, 2, 2, 512)
    x8 = self.down8(x7, training=training) # (bs, 1, 1, 512) > that is the latent size

    x9 = self.up1(x8, x7, training=training) # (bs, 2, 2, 1024) > concat > 512*2
    x10 = self.up2(x9, x6, training=training) # (bs, 4, 4, 1024)
    x11 = self.up3(x10, x5, training=training) # (bs, 8, 8, 1024)
    x12 = self.up4(x11, x4, training=training) # (bs, 16, 16, 1024)
    x13 = self.up5(x12, x3, training=training) # (bs, 32, 32, 512)
    x14 = self.up6(x13, x2, training=training) # (bs, 64, 64, 256)
    x15 = self.up7(x14, x1, training=training) # (bs, 128, 128, 128)

    x16 = self.last(x15) # (bs, 256, 256, 3)
    # The images were normalized to [-1, 1], so tanh (whose range is [-1, 1]) is a suitable output activation
    x16 = tf.nn.tanh(x16)

    return x16

In [None]:
class DiscDownsample(tf.keras.Model):

  def __init__(self, filters, size, apply_batchnorm=True):
    super(DiscDownsample, self).__init__()
    self.apply_batchnorm = apply_batchnorm
    initializer = tf.random_normal_initializer(0., 0.02)

    self.conv1 = tf.keras.layers.Conv2D(filters,
                                        (size, size),
                                        strides=2,
                                        padding='same',
                                        kernel_initializer=initializer,
                                        use_bias=False)
    if self.apply_batchnorm:
        self.batchnorm = tf.keras.layers.BatchNormalization()

  def call(self, x, training):
    x = self.conv1(x)
    if self.apply_batchnorm:
        x = self.batchnorm(x, training=training)
    x = tf.nn.leaky_relu(x)
    return x

class Discriminator(tf.keras.Model):

  def __init__(self):
    super(Discriminator, self).__init__()
    initializer = tf.random_normal_initializer(0., 0.02)

    self.down1 = DiscDownsample(64, 4, False)
    self.down2 = DiscDownsample(128, 4)
    self.down3 = DiscDownsample(256, 4)
    # we are zero padding here with 1 because we need our shape to
    # go from (batch_size, 32, 32, 256) to (batch_size, 31, 31, 512)
    self.zero_pad1 = tf.keras.layers.ZeroPadding2D()
    self.conv = tf.keras.layers.Conv2D(512,
                                       (4, 4),
                                       strides=1,
                                       kernel_initializer=initializer,
                                       use_bias=False)
    self.batchnorm1 = tf.keras.layers.BatchNormalization()

    # shape change from (batch_size, 31, 31, 512) to (batch_size, 30, 30, 1)
    self.zero_pad2 = tf.keras.layers.ZeroPadding2D()
    self.last = tf.keras.layers.Conv2D(1,
                                       (4, 4),
                                       strides=1,
                                       kernel_initializer=initializer)

  def call(self, inp, tar, training):
    # concatenating the input and the target [real image]
    x = tf.concat([inp, tar], axis=-1) # (bs, 256, 256, channels*2) > channels=3
    x = self.down1(x, training=training) # (bs, 128, 128, 64)
    x = self.down2(x, training=training) # (bs, 64, 64, 128)
    x = self.down3(x, training=training) # (bs, 32, 32, 256)

    x = self.zero_pad1(x) # (bs, 34, 34, 256)
    x = self.conv(x)      # (bs, 31, 31, 512)
    x = self.batchnorm1(x, training=training)
    x = tf.nn.leaky_relu(x)

    x = self.zero_pad2(x) # (bs, 33, 33, 512)
    # don't add a sigmoid activation here since
    # the loss function expects raw logits.
    x = self.last(x)      # (bs, 30, 30, 1)

    return x

In [None]:
# The original TF 1.x example decorated the call methods of Generator and
# Discriminator with tf.contrib.eager.defun() for a speedup (~25 seconds per epoch).
# tf.contrib no longer exists in TensorFlow 2.x; here the models run eagerly, undecorated.
generator = Generator()
discriminator = Discriminator()

# Note: the 70x70 patch size is not set explicitly anywhere; it follows from
# the kernel sizes and strides of the discriminator layers.

## Receptive Field Calculation

To calculate the effective receptive field of one output cell, we walk through the convolutional layers from the input to the output. The receptive field $ \text{RF} $ after each layer is:

$$
\text{RF} = \text{RF}_{\text{prev}} + (\text{Kernel Size} - 1) \times \text{Stride}_{\text{prod}}
$$

Where $ \text{RF}_{\text{prev}} $ is the receptive field after the previous layer (1 for the raw input) and $ \text{Stride}_{\text{prod}} $ is the product of the strides of all preceding layers (1 for the first layer).

All five Conv2D layers of the discriminator use 4x4 kernels, so:

1. **down1 (128x128 output)**: 4x4 kernel, stride 2, $ \text{Stride}_{\text{prod}} = 1 $
    - $ \text{RF} = 1 + (4 - 1) \times 1 = 4 $

2. **down2 (64x64 output)**: 4x4 kernel, stride 2, $ \text{Stride}_{\text{prod}} = 2 $
    - $ \text{RF} = 4 + (4 - 1) \times 2 = 10 $

3. **down3 (32x32 output)**: 4x4 kernel, stride 2, $ \text{Stride}_{\text{prod}} = 2 \times 2 = 4 $
    - $ \text{RF} = 10 + (4 - 1) \times 4 = 22 $

4. **conv (31x31 output)**: 4x4 kernel, stride 1, $ \text{Stride}_{\text{prod}} = 2 \times 2 \times 2 = 8 $
    - $ \text{RF} = 22 + (4 - 1) \times 8 = 46 $

5. **last (30x30 output)**: 4x4 kernel, stride 1, $ \text{Stride}_{\text{prod}} = 2 \times 2 \times 2 \times 1 = 8 $
    - $ \text{RF} = 46 + (4 - 1) \times 8 = 70 $

Note: Zero padding layers, batch normalization, and activation functions (ReLU, Leaky ReLU) do not affect the receptive field size, so they are omitted in the calculation.

For this architecture, each cell of the 30x30 output therefore sees a 70x70 region of the original image, which is where the PatchGAN patch size comes from.
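
As a cross-check, the receptive field can also be computed numerically by walking the convolutional layers from the input to the output and accumulating `(kernel - 1)` times the product of the preceding strides. This is a small sketch, with the hypothetical helper `receptive_field` and the layer list read off the `Discriminator` class above (three 4x4 stride-2 convolutions followed by two 4x4 stride-1 convolutions). For these layers it evaluates to 70, the patch size quoted in the architecture description.

```python
def receptive_field(layers):
    """layers: (kernel, stride) pairs ordered from input to output."""
    rf, jump = 1, 1  # jump = product of the strides of all preceding layers
    for kernel, stride in layers:
        rf += (kernel - 1) * jump
        jump *= stride
    return rf, jump

# down1, down2, down3, conv, last
disc_layers = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]
rf, jump = receptive_field(disc_layers)
print(rf, jump)  # 70 8
```

The second value, 8, is the effective stride: neighbouring output cells look at input patches that are shifted by 8 pixels, which is why the 70x70 patches overlap heavily.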

## Define the loss functions and the optimizer

* **Discriminator loss**
  * The discriminator loss function takes 2 inputs: the discriminator's outputs for the **real images** and for the **generated images**
  * real_loss is a sigmoid cross entropy loss of the outputs for the **real images** and an **array of ones (since these are the real images)**
  * generated_loss is a sigmoid cross entropy loss of the outputs for the **generated images** and an **array of zeros (since these are the fake images)**
  * Then the total_loss is the sum of real_loss and generated_loss
  
  The loss measures the difference between the discriminator's output and the label assigned to that input:  
  real image > label=1  
  generated image > label=0  
  Binary cross-entropy still works for a PatchGAN discriminator because it is applied to each patch independently and the per-patch losses are then averaged.

* **Generator loss**
  * It is a sigmoid cross entropy loss of the discriminator's output for the generated images and an **array of ones**.
  * The [paper](https://arxiv.org/abs/1611.07004) also includes L1 loss which is MAE (mean absolute error) between the generated image and the target image.
  * This allows the generated image to become structurally similar to the target image.
  * The formula to calculate the total generator loss = gan_loss + LAMBDA * l1_loss, where LAMBDA = 100. This value was decided by the authors of the [paper](https://arxiv.org/abs/1611.07004).
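
To make the arithmetic of the two losses concrete, here is a framework-free sketch on a handful of made-up numbers. The logits and pixel values are hypothetical toy data (4 patch logits instead of a 30x30 grid, 3 pixels instead of an image), and `bce_with_logits` is an illustrative helper that applies the usual numerically stable form of sigmoid cross entropy; it is not the Keras implementation.

```python
import math

def bce_with_logits(logit, label):
    """Numerically stable sigmoid cross entropy for a single logit."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

def mean(values):
    values = list(values)
    return sum(values) / len(values)

LAMBDA = 100

# Hypothetical toy data
disc_real_logits = [2.0, 1.5, 3.0, 0.5]      # discriminator output on (input, target)
disc_fake_logits = [-1.0, -2.0, 0.5, -0.5]   # discriminator output on (input, generated)
target_pixels = [0.2, -0.4, 0.9]
generated_pixels = [0.1, -0.1, 0.7]

# Discriminator: real patches should be labelled 1, generated patches 0
real_loss = mean(bce_with_logits(z, 1.0) for z in disc_real_logits)
generated_loss = mean(bce_with_logits(z, 0.0) for z in disc_fake_logits)
disc_loss = real_loss + generated_loss

# Generator: wants the discriminator to label generated patches 1,
# plus LAMBDA times the mean absolute error to the target
gan_loss = mean(bce_with_logits(z, 1.0) for z in disc_fake_logits)
l1_loss = mean(abs(t - g) for t, g in zip(target_pixels, generated_pixels))
gen_loss = gan_loss + LAMBDA * l1_loss

print(disc_loss, gan_loss, l1_loss, gen_loss)
```

With LAMBDA = 100, even a small mean absolute error dominates the total generator loss, which is what pushes the output to stay structurally close to the target.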

In [None]:
LAMBDA = 100

In [None]:
# disc_loss = discriminator_loss(disc_real_output, disc_generated_output)
def discriminator_loss(disc_real_output, disc_generated_output):
  # from_logits=True: the inputs are raw, unnormalized scores (any real number);
  # the loss applies the sigmoid internally to turn them into probabilities in (0, 1).
  loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)

  # Both arguments are discriminator outputs of shape (bs, 30, 30, 1), not images.
  # tf.ones_like / tf.zeros_like build label tensors of the same shape: one label
  # per patch (1 = real, 0 = generated). The loss is averaged over all patches.
  real_loss = loss_fn(tf.ones_like(disc_real_output), disc_real_output)
  generated_loss = loss_fn(tf.zeros_like(disc_generated_output), disc_generated_output)

  total_disc_loss = real_loss + generated_loss

  return total_disc_loss

In [None]:
# gen_loss = generator_loss(disc_generated_output, gen_output, target)
def generator_loss(disc_generated_output, gen_output, target):
  loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)

  # the generator wants the discriminator to label its output as real (all ones)
  gan_loss = loss_fn(tf.ones_like(disc_generated_output), disc_generated_output)

  # mean absolute error
  l1_loss = tf.reduce_mean(tf.abs(target - gen_output))

  total_gen_loss = gan_loss + (LAMBDA * l1_loss)

  return total_gen_loss

In [None]:
generator_optimizer = tf.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.optimizers.Adam(2e-4, beta_1=0.5)

## Checkpoints (Object-based saving)

In [None]:
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# creating object to handle saving and restoring the state of the models and their corresponding optimizers.
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

## Training

* We start by iterating over the dataset
* The generator gets the input image and we get a generated output.
* The discriminator receives two pairs: the input_image with the target_image (real), and the input_image with the generated image (fake).
* Next, we calculate the generator and the discriminator loss.
* Then, we calculate the gradients of each loss with respect to the generator and the discriminator variables (weights) and apply them with the optimizers.

## Generate Images

* After training, it's time to generate some images!
* We pass images from the test dataset to the generator.
* The generator will then translate the input image into the output we expect.
* Last step is to plot the predictions and **voila!**
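
Plotting needs one extra step: the images were normalized to `[-1, 1]` in `load_image`, and the generator's tanh output lives in the same range, so values have to be mapped to `[0, 1]` before `plt.imshow` can display them. The two mappings used in this notebook are sketched below on scalar values; the function names `normalize` and `to_display_range` are assumptions introduced for this illustration.

```python
def normalize(pixel):
    """Map a pixel value from [0, 255] to [-1, 1], as in load_image."""
    return pixel / 127.5 - 1

def to_display_range(value):
    """Map a value from [-1, 1] to [0, 1] for plotting, as in generate_images."""
    return value * 0.5 + 0.5

for pixel in (0, 127.5, 255):
    normalized = normalize(pixel)
    print(pixel, normalized, to_display_range(normalized))
```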

In [None]:
EPOCHS = 200

In [None]:
def generate_images(model, test_input, tar):
  # the training=True is intentional here since
  # we want the batch statistics while running the model
  # on the test dataset. If we use training=False, we will get
  # the accumulated statistics learned from the training dataset
  # (which we don't want)
  prediction = model(test_input, training=True)
  plt.figure(figsize=(15,15))

  display_list = [test_input[0], tar[0], prediction[0]]
  title = ['Input Image', 'Ground Truth', 'Predicted Image']

  for i in range(3):
    plt.subplot(1, 3, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

In [None]:
def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for input_image, target in dataset:

      # When you’re training a model, you use a technique called backpropagation, which requires you to
      # compute gradients or partial derivatives of the loss function with respect to the model parameters (or weights).
      # This line is creating two separate gradient tapes: gen_tape and disc_tape. gen_tape will record
      # operations for the generator, and disc_tape will record operations for the discriminator.
      with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(input_image, training=True)

        # the discriminator is evaluated twice: first on the real pair, then on the generated pair
        disc_real_output = discriminator(input_image, target, training=True)
        disc_generated_output = discriminator(input_image, gen_output, training=True)

        gen_loss = generator_loss(disc_generated_output, gen_output, target)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

      # the gradients can be computed after leaving the tape block
      # This calculates the gradients of the losses with respect to the trainable variables
      generator_gradients = gen_tape.gradient(gen_loss,
                                              generator.trainable_variables)
      discriminator_gradients = disc_tape.gradient(disc_loss,
                                                   discriminator.trainable_variables)

      # updating the weights of the generator and the discriminator.
      generator_optimizer.apply_gradients(zip(generator_gradients,
                                              generator.trainable_variables))
      discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                                  discriminator.trainable_variables))

    if epoch % 1 == 0:
        clear_output(wait=True)
        for inp, tar in test_dataset.take(1):
          generate_images(generator, inp, tar)

    # saving (checkpoint) the model every 20 epochs
    if (epoch + 1) % 20 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                        time.time()-start))

In [None]:
train(train_dataset, EPOCHS)

## Restore the latest checkpoint and test

In [None]:
# restoring the latest checkpoint in checkpoint_dir
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

## Testing on the entire test dataset

In [None]:
# Run the trained model on the first 5 test batches (remove .take(5) to cover the entire test dataset)
for inp, tar in test_dataset.take(5):
  generate_images(generator, inp, tar)