# Pix2Pix

* `Image-to-Image Translation with Conditional Adversarial Networks`, [arXiv:1611.07004](https://arxiv.org/abs/1611.07004)
  * Phillip Isola, Jun-Yan Zhu, Tinghui Zhou, and Alexei A. Efros

* This code is compatible with TensorFlow 2.0
* Implemented with [`tf.keras.layers`](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/layers) and [`tf.losses`](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/losses)
* This code is borrowed from [TensorFlow Tutorial code](https://www.tensorflow.org/alpha/tutorials/generative/pix2pix)

## 1. Import modules

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import sys
import time
import glob

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import PIL
import imageio
from IPython import display

import tensorflow as tf
from tensorflow.keras import layers

sys.path.append(os.path.dirname(os.path.abspath('.')))
from utils.image_utils import *
from utils.ops import *

os.environ["CUDA_VISIBLE_DEVICES"]="0"

## 2. Setting hyperparameters

In [None]:
# Training Flags (hyperparameter configuration)
model_name = 'pix2pix'
# Used below as checkpoint_dir.
train_dir = os.path.join('train', model_name, 'exp1')

max_epochs = 200        # number of passes over train_dataset
save_model_epochs = 20  # a checkpoint is written every this many epochs
print_steps = 50        # losses are printed every this many global steps
save_images_epochs = 5  # a sample image is saved every this many epochs
batch_size = 1
learning_rate_D = 2e-4  # Adam learning rate of the discriminator
learning_rate_G = 2e-4  # Adam learning rate of the generator

# Spatial size of the crops / resized images fed to the networks.
IMG_WIDTH = 256
IMG_HEIGHT = 256
# Weight of the L1 term in the generator loss.
LAMBDA = 100

## 3. Load the dataset

You can download this dataset and similar datasets from [here](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/). As mentioned in the [paper](https://arxiv.org/abs/1611.07004) we apply random jittering and mirroring to the training dataset.

* In random jittering, the image is resized to 286 x 286 and then randomly cropped to 256 x 256
* In random mirroring, the image is randomly flipped horizontally, i.e., left to right.

In [None]:
# Dataset names this notebook has been written for.
DATASETS = ["facades", "cityscapes"]

dataset_name = "facades"
#dataset_name = "cityscapes"

# Download the archive into ../datasets and extract it there.
_archive_name = dataset_name + '.tar.gz'
_URL = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/' + _archive_name
path_to_zip = tf.keras.utils.get_file(
    _archive_name,
    origin=_URL,
    extract=True,
    cache_subdir=os.path.abspath('../datasets'))

# Rename the extracted folder to "<dataset_name>_pix2pix", unless a folder
# with that name is already present.
_renamed_dir = dataset_name + '_pix2pix'
if _renamed_dir not in os.listdir('../datasets/'):
  os.rename(os.path.join('../datasets/', dataset_name),
            os.path.join('../datasets/', _renamed_dir))

# Root directory of the dataset used by the rest of the notebook.
PATH = os.path.join(os.path.dirname(path_to_zip), _renamed_dir)

## 4. Set up dataset with `tf.data`

### Image augmentation

In [None]:
def load(image_file):
  """Read one JPEG and split it down the middle into two float32 images.

  Args:
    image_file: path of the JPEG file to read.

  Returns:
    `(input_image, real_image)`: the right half and the left half of the
    decoded picture, in that order, both cast to `tf.float32`.
  """
  decoded = tf.image.decode_jpeg(tf.io.read_file(image_file))

  # Both pictures sit side by side in one file; cut at half the width.
  half_width = tf.shape(decoded)[1] // 2
  left_half = decoded[:, :half_width, :]
  right_half = decoded[:, half_width:, :]

  return tf.cast(right_half, tf.float32), tf.cast(left_half, tf.float32)

In [None]:
inp, re = load(os.path.join(PATH, 'train/100.jpg'))
# load() returns float32 pixel values, so divide by 255 to bring them
# into [0, 1] for matplotlib to show the image
plt.figure()
plt.imshow(inp/255.0)
plt.figure()
plt.imshow(re/255.0)

In [None]:
def resize(input_image, real_image, height, width):
  """Resize both images to `height` x `width` with nearest-neighbor sampling.

  Returns:
    `(input_image, real_image)` after resizing.
  """
  target_size = [height, width]
  nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR

  resized_input = tf.image.resize(input_image, target_size, method=nearest)
  resized_real = tf.image.resize(real_image, target_size, method=nearest)
  return resized_input, resized_real

In [None]:
def random_crop(input_image, real_image):
  """Cut the same random IMG_HEIGHT x IMG_WIDTH window out of both images.

  The two images are stacked first so that a single random crop is applied
  to both of them and they stay aligned.

  Returns:
    `(input_image, real_image)` after cropping.
  """
  pair = tf.stack([input_image, real_image], axis=0)
  crop_shape = [2, IMG_HEIGHT, IMG_WIDTH, 3]
  cropped_pair = tf.image.random_crop(pair, size=crop_shape)

  return cropped_pair[0], cropped_pair[1]

In [None]:
def normalize(input_image, real_image):
  """Rescale both images with `x / 127.5 - 1`.

  Pixel values in [0, 255] are mapped to [-1, 1].

  Returns:
    `(input_image, real_image)` after rescaling.
  """
  half_range = 127.5
  return input_image / half_range - 1, real_image / half_range - 1

In [None]:
@tf.function()
def random_jitter(input_image, real_image):
  """Apply the same random jitter and random mirroring to an image pair.

  Both images are resized to 286 x 286, cropped together at a random
  position (see `random_crop`), and flipped left-right together when a
  uniform random draw exceeds 0.5.

  Returns:
    `(input_image, real_image)` after augmentation.
  """
  enlarged_input, enlarged_real = resize(input_image, real_image, 286, 286)
  jittered_input, jittered_real = random_crop(enlarged_input, enlarged_real)

  mirror = tf.random.uniform(()) > 0.5
  if mirror:
    # flip both images so that they stay aligned
    jittered_input = tf.image.flip_left_right(jittered_input)
    jittered_real = tf.image.flip_left_right(jittered_real)

  return jittered_input, jittered_real

In [None]:
# As you can see in the images below,
# they are going through random jittering.
# Random jittering as described in the paper is to
# 1. Resize an image to a bigger height and width
# 2. Randomly crop to the original size
# 3. Randomly flip the image horizontally

# Show four independently jittered versions of the sample input image.
plt.figure(figsize=(6, 6))
for i in range(4):
  rj_inp, rj_re = random_jitter(inp, re)
  plt.subplot(2, 2, i+1)
  plt.imshow(rj_inp/255.0)
  plt.axis('off')
plt.show()

In [None]:
def load_image_train(image_file):
  """Training preprocessing: load, random jitter, then normalize.

  Returns:
    `(input_image, real_image)` ready to be batched.
  """
  raw_pair = load(image_file)
  augmented_pair = random_jitter(*raw_pair)
  return normalize(*augmented_pair)

In [None]:
def load_image_test(image_file):
  """Evaluation preprocessing: load, resize to the model size, normalize.

  No random augmentation is applied here.

  Returns:
    `(input_image, real_image)` ready to be batched.
  """
  raw_pair = load(image_file)
  resized_pair = resize(raw_pair[0], raw_pair[1], IMG_HEIGHT, IMG_WIDTH)
  return normalize(*resized_pair)

### Input pipeline

* Use `tf.data` to shuffle the dataset, map the preprocessing functions over it, and create batches

In [None]:
file_list = os.listdir(os.path.join(PATH, 'train'))
# Number of entries in the train directory; also used as the shuffle buffer size.
N = BUFFER_SIZE = len(file_list)

# File names are shuffled before decoding, then preprocessed and batched.
# num_parallel_calls is fixed to 16 because
# tf.data.experimental.AUTOTUNE led to an out-of-memory error.
train_dataset = (
    tf.data.Dataset.list_files(os.path.join(PATH, 'train/*.jpg'))
    .shuffle(BUFFER_SIZE)
    .map(load_image_train, num_parallel_calls=16)
    .batch(batch_size)
)

In [None]:
# The file names are shuffled so that a different validation image is
# drawn each time a sample is generated to display the model's progress.
val_dataset = (
    tf.data.Dataset.list_files(os.path.join(PATH, 'val/*.jpg'))
    .shuffle(BUFFER_SIZE)
    .map(load_image_test)
    .batch(batch_size)
)

In [None]:
# A test split is only set up for the datasets listed here.
if dataset_name in ["facades"]:
  # Shuffled for the same reason as the validation set: a different
  # image is drawn each time a sample is generated.
  test_dataset = (
      tf.data.Dataset.list_files(os.path.join(PATH, 'test/*.jpg'))
      .shuffle(BUFFER_SIZE)
      .map(load_image_test)
      .batch(batch_size)
  )

## 5. Write the generator and discriminator models

### Generator
  * The architecture of generator is a modified U-Net.
  * Each block in the encoder is (Conv -> Batchnorm -> Leaky ReLU)
  * Each block in the decoder is (Transposed Conv -> Batchnorm -> Dropout(applied to the first 3 blocks) -> ReLU)
  * There are skip connections between the encoder and decoder (as in U-Net).

In [None]:
class Downsample(tf.keras.Model):
  """Encoder block: stride-2 Conv2D -> optional BatchNorm -> LeakyReLU(0.2).

  Args:
    filters: number of output channels of the convolution.
    size: kernel size of the convolution.
    apply_batchnorm: whether a batch normalization layer follows the conv.
  """

  def __init__(self, filters, size, apply_batchnorm=True):
    super(Downsample, self).__init__()
    self.apply_batchnorm = apply_batchnorm

    weight_init = tf.random_normal_initializer(0., 0.02)
    self.conv = layers.Conv2D(filters=filters,
                              kernel_size=size,
                              strides=2,
                              padding='same',
                              use_bias=False,
                              kernel_initializer=weight_init)
    # the batchnorm layer only exists when it is going to be used
    if apply_batchnorm:
      self.batchnorm = layers.BatchNormalization()
    self.leaky_relu = layers.LeakyReLU(alpha=0.2)

  def call(self, x, training):
    """Run the block on `x`; `training` is forwarded to batch normalization."""
    features = self.conv(x)
    if self.apply_batchnorm:
      features = self.batchnorm(features, training=training)
    return self.leaky_relu(features)

In [None]:
class Upsample(tf.keras.Model):
  """Decoder block with a skip connection.

  Stride-2 Conv2DTranspose -> BatchNorm -> optional Dropout(0.5) -> ReLU,
  followed by a channel-wise concatenation with the skip tensor.

  Args:
    filters: number of output channels of the transposed convolution.
    size: kernel size of the transposed convolution.
    apply_dropout: whether dropout is applied after batch normalization.
  """

  def __init__(self, filters, size, apply_dropout=False):
    super(Upsample, self).__init__()
    self.apply_dropout = apply_dropout

    weight_init = tf.random_normal_initializer(0., 0.02)
    self.up_conv = layers.Conv2DTranspose(filters=filters,
                                          kernel_size=size,
                                          strides=2,
                                          padding='same',
                                          use_bias=False,
                                          kernel_initializer=weight_init)
    self.batchnorm = layers.BatchNormalization()
    # the dropout layer only exists when it is going to be used
    if apply_dropout:
      self.dropout = layers.Dropout(0.5)
    self.relu = layers.ReLU()

  def call(self, x1, x2, training):
    """Upsample `x1` and concatenate the result with the skip tensor `x2`.

    `training` is forwarded to batch normalization and dropout.
    """
    upsampled = self.up_conv(x1)
    upsampled = self.batchnorm(upsampled, training=training)
    if self.apply_dropout:
      upsampled = self.dropout(upsampled, training=training)
    upsampled = self.relu(upsampled)

    return tf.concat([upsampled, x2], axis=-1)

In [None]:
class Generator(tf.keras.Model):
  """U-Net style generator: 8 Downsample blocks, 7 Upsample blocks, tanh head.

  For an input of shape (bs, 256, 256, 3) the encoder outputs are
    down1 (bs, 128, 128, 64)   down2 (bs, 64, 64, 128)
    down3 (bs, 32, 32, 256)    down4 (bs, 16, 16, 512)
    down5 (bs, 8, 8, 512)      down6 (bs, 4, 4, 512)
    down7 (bs, 2, 2, 512)      down8 (bs, 1, 1, 512)
  and the decoder outputs (after concatenation with the skip tensor) are
    up1 (bs, 2, 2, 1024)       up2 (bs, 4, 4, 1024)
    up3 (bs, 8, 8, 1024)       up4 (bs, 16, 16, 1024)
    up5 (bs, 32, 32, 512)      up6 (bs, 64, 64, 256)
    up7 (bs, 128, 128, 128)
  The final transposed convolution yields (bs, 256, 256, 3).
  """

  def __init__(self):
    super(Generator, self).__init__()
    # Encoder; only the first block has no batch normalization.
    self.down1 = Downsample(64, 4, apply_batchnorm=False)
    self.down2 = Downsample(128, 4)
    self.down3 = Downsample(256, 4)
    self.down4 = Downsample(512, 4)
    self.down5 = Downsample(512, 4)
    self.down6 = Downsample(512, 4)
    self.down7 = Downsample(512, 4)
    self.down8 = Downsample(512, 4)

    # Decoder; dropout is used in the first three blocks only.
    self.up1 = Upsample(512, 4, apply_dropout=True)
    self.up2 = Upsample(512, 4, apply_dropout=True)
    self.up3 = Upsample(512, 4, apply_dropout=True)
    self.up4 = Upsample(512, 4)
    self.up5 = Upsample(256, 4)
    self.up6 = Upsample(128, 4)
    self.up7 = Upsample(64, 4)

    # Output layer producing 3 channels.
    self.last = layers.Conv2DTranspose(
        filters=3,
        kernel_size=(4, 4),
        strides=2,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02))

  def call(self, x, training):
    """Translate a batch of input images; the output lies in [-1, 1] (tanh)."""
    encoder = (self.down1, self.down2, self.down3, self.down4,
               self.down5, self.down6, self.down7, self.down8)
    decoder = (self.up1, self.up2, self.up3, self.up4,
               self.up5, self.up6, self.up7)

    # Run the encoder and remember every intermediate output.
    encoded = []
    features = x
    for down in encoder:
      features = down(features, training=training)
      encoded.append(features)

    # `features` is now the bottleneck (output of down8). Each decoder block
    # is paired with the encoder output of matching resolution, going from
    # down7 back to down1.
    for up, skip in zip(decoder, reversed(encoded[:-1])):
      features = up(features, skip, training=training)

    return tf.nn.tanh(self.last(features))

In [None]:
generator = Generator()

# Smoke test: run the untrained generator on the sample image loaded above.
# A batch axis is added; note that `inp` has not gone through normalize().
gen_output = generator(inp[tf.newaxis, ...], training=False)
plt.imshow(gen_output[0, ...])

### Discriminator

* The Discriminator is a PatchGAN.
* Each block in the discriminator is (Conv -> BatchNorm -> Leaky ReLU)
* The shape of the output after the last layer is (batch_size, 30, 30, 1)
* Each value of the 30x30 output classifies a 70x70 portion of the input image (such an architecture is called a PatchGAN).
* Discriminator receives 2 inputs.
  * Input image and the target image, which it should classify as real.
  * Input image and the generated image (output of generator), which it should classify as fake.
  * We concatenate these 2 inputs together in the code (tf.concat([inp, tar], axis=-1))
* Shape of the input travelling through the generator and the discriminator is in the comments in the code.

To learn more about the architecture and the hyperparameters, you can refer to the [paper](https://arxiv.org/abs/1611.07004).

In [None]:
class Discriminator(tf.keras.Model):
  """PatchGAN discriminator that scores an (input, target) image pair.

  For 256 x 256 images the output has shape (bs, 30, 30, 1) and holds raw
  logits; no sigmoid is applied because the loss expects logits.
  """

  def __init__(self):
    super(Discriminator, self).__init__()
    # Three stride-2 blocks; the first one has no batch normalization.
    self.down1 = Downsample(64, 4, False)
    self.down2 = Downsample(128, 4)
    self.down3 = Downsample(256, 4)

    # Zero padding of 1 followed by an unpadded 4x4 convolution:
    # (bs, 32, 32, 256) -> (bs, 34, 34, 256) -> (bs, 31, 31, 512)
    self.zero_pad1 = layers.ZeroPadding2D()
    self.conv = layers.Conv2D(
        filters=512,
        kernel_size=4,
        use_bias=False,
        kernel_initializer=tf.random_normal_initializer(0., 0.02))
    self.batchnorm = layers.BatchNormalization()
    self.leaky_relu = layers.LeakyReLU(alpha=0.2)

    # Same trick once more:
    # (bs, 31, 31, 512) -> (bs, 33, 33, 512) -> (bs, 30, 30, 1)
    self.zero_pad2 = layers.ZeroPadding2D()
    self.last = layers.Conv2D(
        filters=1,
        kernel_size=4,
        kernel_initializer=tf.random_normal_initializer(0., 0.02))

  def call(self, inputs, targets, training):
    """Return patch logits for `inputs` paired with `targets`.

    `targets` is either the real image or the generator output; `training`
    is forwarded to the batch normalization layers.
    """
    # (bs, 256, 256, channels*2): both images stacked along the channel axis
    pair = tf.concat([inputs, targets], axis=-1)

    hidden = self.down1(pair, training=training)    # (bs, 128, 128, 64)
    hidden = self.down2(hidden, training=training)  # (bs, 64, 64, 128)
    hidden = self.down3(hidden, training=training)  # (bs, 32, 32, 256)

    hidden = self.conv(self.zero_pad1(hidden))      # (bs, 31, 31, 512)
    hidden = self.leaky_relu(self.batchnorm(hidden, training=training))

    # raw logits, (bs, 30, 30, 1)
    return self.last(self.zero_pad2(hidden))

In [None]:
discriminator = Discriminator()
# Smoke test: score the sample input paired with the generator output from above.
disc_out = discriminator(inp[tf.newaxis,...], gen_output, training=False)
# Show the single output channel (raw logits) of the first batch element as a heat map.
plt.imshow(disc_out[0,...,-1], vmin=-20, vmax=20, cmap='RdBu_r')
plt.colorbar()

## 6. Model summary

In [None]:
# Print the layer and parameter overview of the generator.
generator.summary()

In [None]:
# Print the layer and parameter overview of the discriminator.
discriminator.summary()

## 7. Define the loss functions and the optimizer

* **Discriminator loss**
  * The discriminator loss function takes 2 inputs; **real images**, **generated images**
  * real_loss is a sigmoid cross entropy loss of the **real images** and an **array of ones(since these are the real images)**
  * generated_loss is a sigmoid cross entropy loss of the **generated images** and an **array of zeros(since these are the fake images)**
  * Then the total_loss is the sum of real_loss and the generated_loss
* **Generator loss**
  * It is a sigmoid cross entropy loss of the generated images and an **array of ones**.
  * The [paper](https://arxiv.org/abs/1611.07004) also includes L1 loss which is MAE (mean absolute error) between the generated image and the target image.
  * This allows the generated image to become structurally similar to the target image.
  * The formula to calculate the total generator loss = gan_loss + LAMBDA * l1_loss, where LAMBDA = 100. This value was decided by the authors of the [paper](https://arxiv.org/abs/1611.07004).

In [None]:
# Binary cross entropy shared by both losses. from_logits=True because the
# discriminator returns raw logits (it applies no sigmoid).
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [None]:
def discriminator_loss(real_logits, fake_logits):
  """Cross-entropy loss of the discriminator.

  Args:
    real_logits: discriminator output for (input, real target) pairs.
    fake_logits: discriminator output for (input, generated image) pairs.

  Returns:
    Loss of `real_logits` against all-ones labels plus loss of
    `fake_logits` against all-zeros labels.
  """
  labels_real = tf.ones_like(real_logits)
  labels_fake = tf.zeros_like(fake_logits)

  return (loss_object(labels_real, real_logits)
          + loss_object(labels_fake, fake_logits))

In [None]:
def generator_loss(fake_logits, gen_output, target):
  """Total generator loss: adversarial term + LAMBDA * L1 term.

  Args:
    fake_logits: discriminator output for the generated images.
    gen_output: images produced by the generator.
    target: real images the generator should reproduce.

  Returns:
    Cross entropy of `fake_logits` against all-ones labels, plus LAMBDA
    times the mean absolute error between `target` and `gen_output`.
  """
  adversarial_term = loss_object(tf.ones_like(fake_logits), fake_logits)
  mean_abs_error = tf.reduce_mean(tf.abs(target - gen_output))

  return adversarial_term + (LAMBDA * mean_abs_error)

In [None]:
# One Adam optimizer per network (beta_1=0.5), each with its own learning rate.
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate_D, beta_1=0.5)
generator_optimizer = tf.keras.optimizers.Adam(learning_rate_G, beta_1=0.5)

## 8. Checkpoints (Object-based saving)

In [None]:
# Checkpoints are written into train_dir with the file prefix "ckpt".
checkpoint_dir = train_dir
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

if not tf.io.gfile.exists(checkpoint_dir):
  tf.io.gfile.makedirs(checkpoint_dir)

# Track both networks together with their optimizers.
checkpoint = tf.train.Checkpoint(
    generator=generator,
    discriminator=discriminator,
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer)

## 9. Define generate_and_print_or_save functions

In [None]:
def generate_and_print_or_save(inputs, targets,
                               is_save=False, epoch=None, checkpoint_dir=checkpoint_dir):
  """Run the generator on `inputs` and hand the result to the sample-image helper.

  The generator is called with training=True on purpose: batch statistics
  of the given batch are wanted here, not the statistics accumulated
  during training.

  Args:
    inputs: batch of input images fed to the generator.
    targets: matching batch of real images, passed on to the helper.
    is_save, epoch, checkpoint_dir: forwarded unchanged to
      `print_or_save_sample_images_pix2pix`.
  """
  generated = generator(inputs, training=True)
  print_or_save_sample_images_pix2pix(
      inputs, targets, generated,
      model_name='pix2pix', name=None,
      is_save=is_save, epoch=epoch, checkpoint_dir=checkpoint_dir)

In [None]:
# Keep one fixed validation batch for generation (prediction) so that
# the improvement of pix2pix is easier to see from epoch to epoch.
# val_dataset is shuffled, so the batch that gets picked differs between runs.
for inputs, targets in val_dataset.take(1):
  constant_val_input = inputs
  constant_val_target = targets

In [None]:
# Check the generation / display path on the fixed validation batch
generate_and_print_or_save(constant_val_input, constant_val_target)

## 10. Training

### Define training one step function

In [None]:
@tf.function()
def train_step(input_images, targets):
  """Run one optimization step for the generator and the discriminator.

  Args:
    input_images: batch of input images.
    targets: matching batch of real target images.

  Returns:
    `(gen_loss, disc_loss)` computed for this batch before the update.
  """
  # One tape per network, so that each loss is differentiated separately.
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    fake_images = generator(input_images, training=True)

    logits_real = discriminator(input_images, targets, training=True)
    logits_fake = discriminator(input_images, fake_images, training=True)

    gen_loss = generator_loss(logits_fake, fake_images, targets)
    disc_loss = discriminator_loss(logits_real, logits_fake)

  gen_variables = generator.trainable_variables
  disc_variables = discriminator.trainable_variables

  gen_gradients = gen_tape.gradient(gen_loss, gen_variables)
  disc_gradients = disc_tape.gradient(disc_loss, disc_variables)

  generator_optimizer.apply_gradients(zip(gen_gradients, gen_variables))
  discriminator_optimizer.apply_gradients(zip(disc_gradients, disc_variables))

  return gen_loss, disc_loss

### Training until max_epochs

In [None]:
print('Start Training.')
# N is the number of entries found in the train directory.
num_batches_per_epoch = int(N / batch_size)
# Counts training steps across all epochs; drives the print_steps interval.
global_step = tf.Variable(0, trainable=False)

for epoch in range(max_epochs):
  
  for step, (input_image, target) in enumerate(train_dataset):
    start_time = time.time()
    
    gen_loss, disc_loss = train_step(input_image, target)
    global_step.assign_add(1)

    if global_step.numpy() % print_steps == 0:
      # fractional epoch counter, used for display only
      epochs = epoch + step / float(num_batches_per_epoch)
      # duration covers only the train_step call of the current step
      duration = time.time() - start_time
      examples_per_sec = batch_size / float(duration)
      display.clear_output(wait=True)
      print("Epochs: {:.2f} global_step: {} loss_D: {:.3g} loss_G: {:.3g} ({:.2f} examples/sec; {:.3f} sec/batch)".format(
                epochs, global_step.numpy(), disc_loss, gen_loss, examples_per_sec, duration))
      # generate a sample from a randomly drawn validation image.
      # generate_and_print_or_save runs the generator with training=True;
      # this is intentional since we want the batch statistics while
      # running the model on held-out data. With training=False we would get
      # the accumulated statistics learned from the training dataset
      # (which we don't want)
      for val_input, val_target in val_dataset.take(1):
        generate_and_print_or_save(val_input, val_target)

  # saving a sample of the fixed validation batch every save_images_epochs epochs
  if (epoch + 1) % save_images_epochs == 0:
    display.clear_output(wait=True)
    print("This images are saved at {} epoch".format(epoch+1))
    generate_and_print_or_save(constant_val_input, constant_val_target,
                               is_save=True, epoch=epoch+1, checkpoint_dir=checkpoint_dir)

  # saving (checkpoint) the model every save_model_epochs epochs
  if (epoch + 1) % save_model_epochs == 0:
    checkpoint.save(file_prefix = checkpoint_prefix)
    
print('Training Done.')

In [None]:
# generating after the final epoch
display.clear_output(wait=True)
# Draw one validation batch and show the generator output next to its own
# target. Both loop variables are passed on, so input and target always
# belong to the same sample.
for val_input, val_target in val_dataset.take(1):
  generate_and_print_or_save(val_input, val_target)

## 11. Restore the latest checkpoint

In [None]:
# restoring the latest checkpoint in checkpoint_dir; this covers the
# generator, the discriminator and both optimizers tracked by `checkpoint`
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

## 12. Display an image using the epoch number

In [None]:
# display_image is not defined in this notebook (presumably it comes from the
# utils star imports); it is asked for epoch number max_epochs here.
# NOTE(review): presumably this only works if a sample image was saved for
# that epoch, i.e. max_epochs is a multiple of save_images_epochs - confirm.
display_image(max_epochs, checkpoint_dir=checkpoint_dir)

## 13. Generate a GIF of all the saved images.

In [None]:
# Name of the GIF file: "<model_name>_<dataset_name>.gif".
filename = model_name + '_' + dataset_name + '.gif'
# generate_is not defined in this notebook; presumably it builds the GIF
# from the images saved under checkpoint_dir - verify against the utils module.
generate_gif(filename, checkpoint_dir)

In [None]:
# The file shown is "<filename>.png", i.e. the GIF name with ".png" appended.
# NOTE(review): this relies on generate_gif having written such a file - confirm.
display.Image(filename=filename + '.png')