<a href="https://colab.research.google.com/github/wayne0git/tensorflow_basic/blob/main/example/pix2pix_gan_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Pix2Pix GAN for Image Translation
Reference: https://pyimagesearch.com/2022/07/27/image-translation-with-pix2pix/

## Import libraries

In [1]:
import pathlib
import os

In [2]:
import matplotlib.pyplot as plt

from matplotlib.pyplot import subplots

In [3]:
import tensorflow as tf
# shorthand passed to Dataset.map(num_parallel_calls=...) and Dataset.prefetch(...) in load_dataset
AUTO = tf.data.AUTOTUNE

from tensorflow.keras import Model
from tensorflow.keras import Input

from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import concatenate
from tensorflow.keras.layers import MaxPool2D
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Dropout

from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.losses import MeanAbsoluteError

from tensorflow.keras.optimizers import Adam

from tensorflow.keras.callbacks import Callback

from tensorflow.keras.preprocessing.image import array_to_img
from tensorflow.keras.utils import get_file

## Parameters

In [4]:
# name of the pix2pix dataset to download (substituted into DATASET_URL below)
DATASET = "cityscapes"

# build the URL of the dataset's .tar.gz archive
DATASET_URL = f"http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/{DATASET}.tar.gz"

# dataset specs: IMAGE_HEIGHT/IMAGE_WIDTH are the size every image is resized to
# and the input size of both networks
IMAGE_WIDTH = 256
IMAGE_HEIGHT = 256
# NOTE(review): IMAGE_CHANNELS is not referenced by any other cell visible here
IMAGE_CHANNELS = 3

In [5]:
# batch sizes of the training and test datasets; INFER_BATCH_SIZE is also the
# number of rows drawn in the monitoring figure
TRAIN_BATCH_SIZE = 32
INFER_BATCH_SIZE = 8

# training specs: learning rate of both Adam optimizers, number of epochs and
# number of batches per epoch passed to fit()
LEARNING_RATE = 2e-4
EPOCHS = 1
STEPS_PER_EPOCH = 100

In [6]:
# path to our base output directory (relative to the current working directory)
BASE_OUTPUT_PATH = "outputs"

# location the trained generator is saved to at the end of the notebook
GENERATOR_MODEL = os.path.join(BASE_OUTPUT_PATH, "models", "generator")

# define the path to the inferred images and to the grid image
BASE_IMAGES_PATH = os.path.join(BASE_OUTPUT_PATH, "images")
# NOTE(review): GRID_IMAGE_PATH is not referenced by any other cell visible here
GRID_IMAGE_PATH = os.path.join(BASE_IMAGES_PATH, "grid.png")

## Data Preparation

In [7]:
def load_image(imageFile):
	"""Read a side-by-side JPEG and split it into (input mask, real image).

	The file is decoded with 3 channels and cut at the midpoint of its width:
	the right half is returned as the input mask, the left half as the real
	image. Both halves are converted to float32 and rescaled from [0, 255]
	to [-1, 1].

	Args:
		imageFile: path of the JPEG file to read.

	Returns:
		Tuple (inputMask, realImage) of float32 tensors.
	"""
	def _scale_to_unit_range(pixels):
		# 0 -> -1 and 255 -> 1
		return tf.cast(pixels, tf.float32)/127.5 - 1

	# read the raw bytes and decode them as a 3-channel image
	rawBytes = tf.io.read_file(imageFile)
	combined = tf.io.decode_jpeg(rawBytes, channels=3)

	# column index at which the combined picture is cut in two
	half = tf.shape(combined)[1] // 2
	rightHalf = combined[:, half:, :]
	leftHalf = combined[:, :half, :]

	# the right half is the input mask, the left half is the real image
	return (_scale_to_unit_range(rightHalf), _scale_to_unit_range(leftHalf))

In [8]:
def random_jitter(inputMask, realImage, height, width):
	"""Resize both images of a pair to (height, width) with nearest-neighbour interpolation.

	Note: despite its name, this function applies no random operation (no
	random crop or flip); it only resizes the pair.

	Args:
		inputMask: input mask image tensor.
		realImage: real (target) image tensor.
		height: target height.
		width: target width.

	Returns:
		Tuple (inputMask, realImage) of the resized tensors.
	"""
	targetSize = [height, width]
	nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR

	# both images go through the exact same resize
	resizedMask = tf.image.resize(inputMask, targetSize, method=nearest)
	resizedReal = tf.image.resize(realImage, targetSize, method=nearest)

	return (resizedMask, resizedReal)

In [9]:
class ReadTrainExample(object):
	"""Callable that turns a training file path into an (input mask, real image) pair.

	Args:
		imageHeight: height of the returned images.
		imageWidth: width of the returned images.
	"""
	def __init__(self, imageHeight, imageWidth):
		self.imageHeight, self.imageWidth = imageHeight, imageWidth

	def __call__(self, imageFile):
		"""Load the pair stored in imageFile and bring it to the configured size."""
		# split the combined file into its two halves
		(rawMask, rawReal) = load_image(imageFile)

		# first enlarge both images by 30 pixels in each dimension
		# (random_jitter, as defined in this notebook, only resizes)
		enlargedHeight = self.imageHeight+30
		enlargedWidth = self.imageWidth+30
		(bigMask, bigReal) = random_jitter(rawMask, rawReal, enlargedHeight, enlargedWidth)

		# then resize back down to the final size with tf.image.resize's default method
		finalSize = [self.imageHeight, self.imageWidth]
		inputMask = tf.image.resize(bigMask, finalSize)
		realImage = tf.image.resize(bigReal, finalSize)

		return (inputMask, realImage)

In [10]:
class ReadTestExample(object):
	"""Callable that turns a test file path into an (input mask, real image) pair.

	Unlike the training reader, the pair is resized directly to the target
	size without the intermediate enlargement step.

	Args:
		imageHeight: height of the returned images.
		imageWidth: width of the returned images.
	"""
	def __init__(self, imageHeight, imageWidth):
		self.imageHeight, self.imageWidth = imageHeight, imageWidth

	def __call__(self, imageFile):
		"""Load the pair stored in imageFile and resize it to the configured size."""
		# split the combined file into its two halves
		(rawMask, rawReal) = load_image(imageFile)

		# resize both halves with tf.image.resize's default method
		finalSize = [self.imageHeight, self.imageWidth]
		inputMask = tf.image.resize(rawMask, finalSize)
		realImage = tf.image.resize(rawReal, finalSize)

		return (inputMask, realImage)

In [11]:
def load_dataset(path, batchSize, height, width, train=False):
	"""Build an endlessly repeating tf.data pipeline of (input mask, real image) batches.

	Args:
		path: pathlib-style dataset directory containing "train" and "val" sub-folders
			(it is combined with the file pattern using the "/" operator).
		batchSize: number of pairs per batch; the shuffle buffer holds twice as many.
		height: height every image is resized to.
		width: width every image is resized to.
		train: if True read "train/*.jpg" with ReadTrainExample, otherwise
			read "val/*.jpg" with ReadTestExample.

	Returns:
		A tf.data.Dataset that is shuffled, batched, repeated and prefetched.
	"""
	# pick the file pattern and the matching reader for the requested split
	if train:
		pattern, reader = "train/*.jpg", ReadTrainExample(height, width)
	else:
		pattern, reader = "val/*.jpg", ReadTestExample(height, width)

	# list the files and decode them in parallel
	files = tf.data.Dataset.list_files(str(path/pattern))
	examples = files.map(reader, num_parallel_calls=AUTO)

	# shuffle -> batch -> repeat -> prefetch, one stage per line
	shuffled = examples.shuffle(batchSize * 2)
	batched = shuffled.batch(batchSize)
	repeated = batched.repeat()

	return repeated.prefetch(AUTO)

In [12]:
# download the cityscape training dataset 
# download the dataset archive with get_file (extract=True asks it to unpack the archive as well)
pathToZip = get_file(fname=f"{DATASET}.tar.gz", origin=DATASET_URL,	extract=True)
# the dataset directory is looked up next to the returned path, in a folder named after DATASET
# NOTE(review): this assumes get_file returns the archive path and extracts beside it — confirm for the installed Keras version
pathToZip  = pathlib.Path(pathToZip)
path = pathToZip.parent/DATASET

Downloading data from http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/cityscapes.tar.gz


In [13]:
# create dataset
# training pipeline: reads <path>/train/*.jpg in batches of TRAIN_BATCH_SIZE
trainDs = load_dataset(path=path, train=True, batchSize=TRAIN_BATCH_SIZE, height=IMAGE_HEIGHT, width=IMAGE_WIDTH)
# test pipeline: reads <path>/val/*.jpg in batches of INFER_BATCH_SIZE (used by the training monitor)
testDs = load_dataset(path=path, train=False, batchSize=INFER_BATCH_SIZE, height=IMAGE_HEIGHT, width=IMAGE_WIDTH)

## Create Model

In [14]:
class Pix2Pix(object):
	"""Builds the two networks of the pix2pix GAN: a U-Net style generator and a
	patch-based discriminator.

	Args:
		imageHeight: height of the images the networks consume.
		imageWidth: width of the images the networks consume.
		imageChannels: number of channels of both the input mask and the
			target/generated image. Defaults to 3, the value that used to be
			hard-coded in every layer definition.

	The shapes quoted in the comments below assume 256*256*3 inputs.
	"""
	def __init__(self, imageHeight, imageWidth, imageChannels=3):
		self.imageHeight = imageHeight
		self.imageWidth = imageWidth
		self.imageChannels = imageChannels

	def generator(self):
		"""Create the U-Net style generator.

		Four conv + max-pool stages shrink the input, a bottleneck follows, and
		four transposed-conv stages grow it back while concatenating the
		matching encoder features (skip connections).

		NOTE(review): because the encoder halves the resolution four times and
		the decoder doubles it four times before each concatenation, height and
		width presumably need to be multiples of 16 for the skip connections to
		line up — confirm before using other image sizes.

		Returns:
			A Keras Model mapping an input mask to a tanh-activated image with
			imageChannels channels.
		"""
		# initialize the input layer (256*256*3)
		inputs = Input([self.imageHeight, self.imageWidth, self.imageChannels])

		# down layer 1 (d1) => final layer 1 (f1) (128*128*32)
		d1 = Conv2D(32, (3, 3), activation="relu", padding="same")(inputs)
		d1 = Dropout(0.1)(d1)
		f1 = MaxPool2D((2, 2))(d1)

		# down layer 2 (d2) => final layer 2 (f2) (64*64*64)
		d2 = Conv2D(64, (3, 3), activation="relu", padding="same")(f1)
		f2 = MaxPool2D((2, 2))(d2)

		# down layer 3 (d3) => final layer 3 (f3) (32*32*96)
		d3 = Conv2D(96, (3, 3), activation="relu", padding="same")(f2)
		f3 = MaxPool2D((2, 2))(d3)

		# down layer 4 (d4) => final layer 4 (f4) (16*16*96)
		d4 = Conv2D(96, (3, 3), activation="relu", padding="same")(f3)
		f4 = MaxPool2D((2, 2))(d4)

		# u-bend of the u-net (16*16*256)
		b5 = Conv2D(96, (3, 3), activation="relu", padding="same")(f4)
		b5 = Dropout(0.3)(b5)
		b5 = Conv2D(256, (3, 3), activation="relu", padding="same")(b5)

		# upsample layer 6 (u6) (32*32*128), skip connection from d4
		u6 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding="same")(b5)
		u6 = concatenate([u6, d4])
		u6 = Conv2D(128, (3, 3), activation="relu", padding="same")(u6)

		# upsample layer 7 (u7) (64*64*128), skip connection from d3
		u7 = Conv2DTranspose(96, (2, 2), strides=(2, 2), padding="same")(u6)
		u7 = concatenate([u7, d3])
		u7 = Conv2D(128, (3, 3), activation="relu", padding="same")(u7)

		# upsample layer 8 (u8) (128*128*128), skip connection from d2
		u8 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding="same")(u7)
		u8 = concatenate([u8, d2])
		u8 = Conv2D(128, (3, 3), activation="relu", padding="same")(u8)

		# upsample layer 9 (u9) (256*256*128), skip connection from d1
		u9 = Conv2DTranspose(32, (2, 2), strides=(2, 2), padding="same")(u8)
		u9 = concatenate([u9, d1])
		u9 = Dropout(0.1)(u9)
		u9 = Conv2D(128, (3, 3), activation="relu", padding="same")(u9)

		# final 1x1 conv layer (256*256*3); tanh matches the [-1, 1] range of the images
		outputLayer = Conv2D(self.imageChannels, (1, 1), activation="tanh")(u9)

		# create the generator model
		generator = Model(inputs, outputLayer)

		return generator

	def discriminator(self):
		"""Create the patch-based (PatchGAN) discriminator.

		Returns:
			A Keras Model taking [input mask, target image] and returning a
			single-channel map of un-activated scores (30*30*1 for 256*256
			inputs), i.e. logits.
		"""
		# initialize the two input layers according to PatchGAN
		inputMask = Input(shape=[self.imageHeight, self.imageWidth, self.imageChannels], name="input_image")
		targetImage = Input(shape=[self.imageHeight, self.imageWidth, self.imageChannels], name="target_image")

		# concatenate the inputs along the channel axis (256*256*6)
		x = concatenate([inputMask, targetImage])

		# add four conv2D layers; the first three halve the resolution
		x = Conv2D(64, 4, strides=2, padding="same")(x)  # (128*128*64)
		x = LeakyReLU()(x)

		x = Conv2D(128, 4, strides=2, padding="same")(x)  # (64*64*128)
		x = LeakyReLU()(x)

		x = Conv2D(256, 4, strides=2, padding="same")(x)  # (32*32*256)
		x = LeakyReLU()(x)

		x = Conv2D(512, 4, strides=1, padding="same")(x)  # (32*32*512)

		# add a batch-normalization layer => LeakyReLU
		x = BatchNormalization()(x)
		x = LeakyReLU()(x)

		# final conv layer without padding or activation (30*30*1)
		last = Conv2D(1, 3, strides=1)(x)

		# create the discriminator model
		discriminator = Model(inputs=[inputMask, targetImage], outputs=last)

		return discriminator

In [15]:
# initialize the generator and discriminator network
# build the network factory for the configured image size
pix2pixObject = Pix2Pix(imageHeight=IMAGE_HEIGHT, imageWidth=IMAGE_WIDTH)
# instantiate the generator and discriminator Keras models
generator = pix2pixObject.generator()
discriminator = pix2pixObject.discriminator()

## Train Model

In [16]:
class Pix2PixTraining(Model):
	"""Keras Model that trains a generator and a discriminator adversarially.

	Args:
		generator: model mapping an input mask to a generated image.
		discriminator: model scoring [input mask, image] pairs.
		l1Weight: factor applied to the L1 (mean absolute error) term of the
			generator loss. Defaults to 10, the value that used to be
			hard-coded in train_step.
	"""
	def __init__(self, generator, discriminator, l1Weight=10):
		super().__init__()

		self.generator = generator
		self.discriminator = discriminator
		self.l1Weight = l1Weight

	def compile(self, gOptimizer, dOptimizer, bceLoss, maeLoss):
		"""Store the optimizers and loss functions used by train_step.

		Args:
			gOptimizer: optimizer applied to the generator variables.
			dOptimizer: optimizer applied to the discriminator variables.
			bceLoss: loss comparing discriminator outputs with 0/1 labels.
			maeLoss: loss comparing generated images with real images.
		"""
		super().compile()

		# initialize the optimizers for the generator and discriminator
		self.gOptimizer = gOptimizer
		self.dOptimizer = dOptimizer

		# initialize the loss functions
		self.bceLoss = bceLoss
		self.maeLoss = maeLoss

	def train_step(self, inputs):
		"""Run one optimisation step of both networks on a batch.

		Args:
			inputs: tuple (inputMask, realImages) as produced by the dataset.

		Returns:
			Dict with the discriminator loss ("dLoss") and the generator
			loss ("gLoss") of this step.
		"""
		# grab the input mask and corresponding real images
		(inputMask, realImages) = inputs

		# one tape per network so each loss is differentiated independently
		with tf.GradientTape() as genTape, tf.GradientTape() as discTape:
			# generate fake images
			fakeImages = self.generator(inputMask, training=True)

			# discriminator output for real images and fake images
			discRealOutput = self.discriminator([inputMask, realImages], training=True)
			discFakeOutput = self.discriminator([inputMask, fakeImages], training=True)

			# adversarial loss for the generator: it wants fakes to be labelled as real (1)
			misleadingImageLabels = tf.ones_like(discFakeOutput)
			ganLoss = self.bceLoss(misleadingImageLabels, discFakeOutput)

			# compute the mean absolute error between the fake and the real images
			l1Loss = self.maeLoss(realImages, fakeImages)

			# total generator loss: adversarial term plus weighted L1 term
			totalGenLoss = ganLoss + (self.l1Weight * l1Loss)

			# discriminator loss: real pairs should be labelled 1 ...
			realImageLabels = tf.ones_like(discRealOutput)
			realDiscLoss = self.bceLoss(realImageLabels, discRealOutput)

			# ... and generated pairs should be labelled 0
			fakeImageLabels = tf.zeros_like(discFakeOutput)
			generatedLoss = self.bceLoss(fakeImageLabels, discFakeOutput)

			# compute the total discriminator loss
			totalDiscLoss = realDiscLoss + generatedLoss

		# calculate the generator and discriminator gradients
		generatorGradients = genTape.gradient(totalGenLoss, self.generator.trainable_variables)
		discriminatorGradients = discTape.gradient(totalDiscLoss, self.discriminator.trainable_variables)

		# apply the gradients to optimize the generator and discriminator
		self.gOptimizer.apply_gradients(zip(generatorGradients, self.generator.trainable_variables))
		self.dOptimizer.apply_gradients(zip(discriminatorGradients, self.discriminator.trainable_variables))

		# return the generator and discriminator losses
		return {"dLoss": totalDiscLoss, "gLoss": totalGenLoss}

In [17]:
def get_train_monitor(testDs, imagePath, batchSize, epochInterval):
	"""Create a callback that periodically saves a grid of generator predictions.

	One batch is taken from testDs when this function is called and reused for
	every snapshot, so successive images show the same examples.

	Args:
		testDs: dataset yielding (input mask, real image) batches.
		imagePath: directory the PNG snapshots are written to; files are
			named after the zero-padded epoch index.
		batchSize: number of rows of the figure (one example per row).
		epochInterval: a snapshot is saved whenever the epoch index is a
			multiple of this value; a falsy value disables the snapshots.

	Returns:
		The Keras Callback instance to pass to fit().
	"""
	# grab the input mask and the real image from the testing dataset
	(tInputMask, tRealImage) = next(iter(testDs))

	class TrainMonitor(Callback):
		def __init__(self, epochInterval=None):
			# let the Callback base class set up its own state first
			super().__init__()
			self.epochInterval = epochInterval

		def on_epoch_end(self, epoch, logs=None):
			# skip epochs that do not fall on the requested interval
			if not self.epochInterval or epoch % self.epochInterval != 0:
				return

			# get the pix2pix prediction
			tPix2pixGenPred = self.model.generator.predict(tInputMask)

			# squeeze=False keeps axes two-dimensional even when batchSize is 1,
			# so each row can always be indexed as ax[0], ax[1], ax[2]
			(fig, axes) = subplots(nrows=batchSize, ncols=3, figsize=(50, 50), squeeze=False)

			# plot one example per row: input, prediction, ground truth
			for (ax, inp, pred, tgt) in zip(axes, tInputMask, tPix2pixGenPred, tRealImage):
				# plot the input mask image
				ax[0].imshow(array_to_img(inp))
				ax[0].set_title("Input Image")

				# plot the predicted Pix2Pix image
				ax[1].imshow(array_to_img(pred))
				ax[1].set_title("Pix2Pix Prediction")

				# plot the ground truth
				ax[2].imshow(array_to_img(tgt))
				ax[2].set_title("Target Label")

			# save and release this specific figure rather than the implicit current one
			fig.savefig(f"{imagePath}/{epoch:03d}.png")
			plt.close(fig)

	# instantiate a train monitor callback
	trainMonitor = TrainMonitor(epochInterval=epochInterval)

	return trainMonitor

In [18]:
# fix TensorFlow's global random seed
# NOTE(review): the generator and discriminator were already built in an earlier cell,
# so this seed presumably does not cover their initial weights — confirm, and move it
# before model creation if reproducible initialization is wanted
tf.random.set_seed(42)

In [19]:
# compile the pix2pix model
# wrap the two networks in the training model
pix2pixModel = Pix2PixTraining(generator=generator,	discriminator=discriminator)
# one Adam optimizer per network; from_logits=True because the discriminator's
# last layer has no activation
pix2pixModel.compile(dOptimizer=Adam(learning_rate=LEARNING_RATE),
                     gOptimizer=Adam(learning_rate=LEARNING_RATE),
                     bceLoss=BinaryCrossentropy(from_logits=True),
                     maeLoss=MeanAbsoluteError())

In [20]:
# check whether output model directory exists
# create the base output directory; exist_ok=True makes this a no-op when the
# directory is already there and avoids a separate existence check
os.makedirs(BASE_OUTPUT_PATH, exist_ok=True)

In [21]:
# check whether output image directory exists, if it doesn't, then create it
# create the directory for the monitoring images; exist_ok=True makes this a
# no-op when the directory is already there and avoids a separate existence check
os.makedirs(BASE_IMAGES_PATH, exist_ok=True)

In [None]:
# train the pix2pix model
# monitoring callback: saves a prediction grid into BASE_IMAGES_PATH whenever the
# epoch index is a multiple of 10
callbacks = [get_train_monitor(testDs, epochInterval=10, imagePath=BASE_IMAGES_PATH, batchSize=INFER_BATCH_SIZE)]
# train the pix2pix model; steps_per_epoch is given explicitly because trainDs repeats endlessly
pix2pixModel.fit(trainDs, epochs=EPOCHS, callbacks=callbacks, steps_per_epoch=STEPS_PER_EPOCH)

In [None]:
# save the pix2pix generator
# save only the generator (the discriminator is not persisted)
# NOTE(review): GENERATOR_MODEL has no file extension — confirm the installed Keras version accepts such a path
pix2pixModel.generator.save(GENERATOR_MODEL)