# Data Preprocessing

Importing the libraries

In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

Image input pipeline

In [None]:
# Report how many GPUs TensorFlow can see before any heavy work starts.
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

# Directory that is globbed for the ground-truth PNG images.
path = '/content/drive/MyDrive/GAN_Data/trainingDataNC/truth_small'
# Directory that is globbed for the artifact-corrupted PNG images.
path2 = '/content/drive/MyDrive/GAN_Data/trainingDataNC/artifact_10x_small'
BATCH_SIZE = 4  # batch size applied to both datasets after jittering
THETA = 1.5  # weight on the cross-entropy term of both loss functions
MARGIN = 3  # margin the perceptual loss is subtracted from in the discriminator loss
COUNT = 0  # initial value of the step counter handed to train()
EPOCHS = 400  # number of epochs handed to train()

def load_image(image):
  """Read a PNG file and return its pixels as a single-channel float32 tensor.

  Args:
    image: Path of the PNG file to read (despite the name, this is the file
      path, not pixel data).

  Returns:
    A float32 tensor of shape [height, width, 1] holding the decoded pixel
    values (not yet rescaled).
  """
  raw_bytes = tf.io.read_file(image)
  # Pin the channel count: crop() reshapes every image to one channel and both
  # models take [512, 512, 1] inputs, so an RGB/RGBA file would otherwise only
  # fail later with a reshape error. Single-channel files decode as before.
  pixels = tf.image.decode_png(raw_bytes, channels=1)
  return tf.cast(pixels, tf.float32)

def normalize(image):
  """Rescale pixel values linearly so that 0 maps to -1 and 255 maps to 1.

  Args:
    image: Pixel values (tensor, array or scalar) on a 0-255 scale.

  Returns:
    The input divided by 127.5 and shifted down by one.
  """
  scaled = image / 127.5
  return scaled - 1

def preprocess(image):
  """Load the PNG at the given path and rescale it with normalize().

  Args:
    image: Path of the PNG file; forwarded unchanged to load_image().

  Returns:
    The decoded image after normalization.
  """
  return normalize(load_image(image))

# Both folders are listed with shuffle=False so each dataset has a fixed order;
# the two are later paired element-by-element (presumably the folders use
# matching file names - verify). Batching is intentionally left until after
# the jittering step.
truth_dataset = (
    tf.data.Dataset.list_files(f'{path}/*.png', shuffle=False)
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
)

artifact_dataset = (
    tf.data.Dataset.list_files(f'{path2}/*.png', shuffle=False)
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
)

**Helper Functions for Transform Dataset Method**

In [None]:
def crop(image_a, image_t):
  """Cut the same random 512x512 window out of an artifact/truth image pair.

  Args:
    image_a: Artifact image containing 542*542 values.
    image_t: Truth image containing 542*542 values.

  Returns:
    Tuple (cropped_artifact, cropped_truth), each of shape [512, 512, 1].
  """
  # Stack the pair into one tensor so a single random_crop call selects one
  # window that is shared by both images.
  pair = tf.stack(
      [tf.reshape(image_a, [542, 542, 1]), tf.reshape(image_t, [542, 542, 1])],
      axis=0)
  cropped_pair = tf.image.random_crop(pair, size=[2, 512, 512, 1])
  return cropped_pair[0], cropped_pair[1]

def resize(image_a, image_t):
  """Resize both images to 542x542 using nearest-neighbour sampling.

  Args:
    image_a: Artifact image.
    image_t: Truth image.

  Returns:
    Tuple (resized_artifact, resized_truth).
  """
  target_size = [542, 542]
  nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR
  resized_a = tf.image.resize(images=image_a, size=target_size, method=nearest)
  resized_t = tf.image.resize(images=image_t, size=target_size, method=nearest)
  return resized_a, resized_t

def random_jittering(image_a, image_t):
  """Apply the same random crop and random horizontal flip to an image pair.

  Both images are resized to 542x542, cropped with one shared random 512x512
  window and, with probability 0.5, mirrored left-to-right together.

  Args:
    image_a: Artifact image.
    image_t: Truth image.

  Returns:
    Tuple (jittered_artifact, jittered_truth).
  """
  image_a, image_t = resize(image_a, image_t)
  image_a, image_t = crop(image_a, image_t)
  # Draw a scalar (shape=[]) rather than a one-element vector so the branch
  # condition is a proper rank-0 predicate.
  if tf.random.uniform(shape=[]) >= 0.5:
    image_a = tf.image.flip_left_right(image_a)
    image_t = tf.image.flip_left_right(image_t)
  return image_a, image_t

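**Transform Dataset Method:**
This method applies random jittering to every truth/artifact image pair: both images are resized to 542×542, cropped with the same random 512×512 window, and randomly flipped left-to-right together. This augmentation helps the network generalize and helps prevent overfitting.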

In [None]:
def transform_dataset(truth_dataset, artifact_dataset):
  """Jitter every truth/artifact pair and rebuild the two datasets.

  The input datasets are iterated in lockstep and each pair is passed through
  random_jittering(). The jittered tensors are materialized here, so iterating
  the returned datasets repeatedly yields the same crops and flips.

  Args:
    truth_dataset: Dataset of truth images.
    artifact_dataset: Dataset of artifact images, in the same order.

  Returns:
    Tuple (truth_dataset, artifact_dataset) built from the jittered images.
  """
  jittered_pairs = [
      random_jittering(a_image, t_image)
      for t_image, a_image in zip(truth_dataset, artifact_dataset)
  ]
  # random_jittering returns (artifact, truth), in that order.
  jittered_art = [pair[0] for pair in jittered_pairs]
  jittered_truth = [pair[1] for pair in jittered_pairs]
  truth_dataset = tf.data.Dataset.from_tensor_slices(jittered_truth)
  artifact_dataset = tf.data.Dataset.from_tensor_slices(jittered_art)
  return truth_dataset, artifact_dataset

# Creating the Models

Defining the discriminator

In [None]:
def conv_layer(filters, strides, input_layer, apply_batchnorm=True):
  """Build a 3x3 convolution, optional batch norm and LeakyReLU(0.2).

  Args:
    filters: Number of convolution filters.
    strides: Convolution strides.
    input_layer: Keras tensor the convolution is applied to.
    apply_batchnorm: Whether to insert BatchNormalization after the convolution.

  Returns:
    The Keras tensor produced by the LeakyReLU activation.
  """
  kernel_init = tf.random_normal_initializer(0., 0.02)
  conv = tf.keras.layers.Conv2D(
      filters, (3, 3), strides=strides, padding='same',
      kernel_initializer=kernel_init)
  features = conv(input_layer)
  if apply_batchnorm:
    features = tf.keras.layers.BatchNormalization()(features)
  return tf.keras.layers.LeakyReLU(0.2)(features)


def create_discriminator_model():
  """Build the discriminator for 512x512 single-channel images.

  Returns:
    A Keras model with five outputs: the sigmoid real/fake score followed by
    the four intermediate feature maps D1-D4 that feed the perceptual loss.
  """
  kernel_init = tf.random_normal_initializer(0., 0.02)
  image_in = tf.keras.layers.Input(shape=[512, 512, 1])

  # First feature map (D1): plain stride-1 convolution without batch norm.
  stem = tf.keras.layers.Conv2D(
      64, (3, 3), strides=(1, 1), padding='same',
      kernel_initializer=kernel_init)(image_in)
  feature_maps = [tf.keras.layers.LeakyReLU(0.2)(stem)]

  # Each stage is a run of (filters, strides) conv_layer blocks; the output of
  # the last block of a stage is the next tapped feature map (D2, D3, D4).
  stages = (
      ((128, (2, 2)), (128, (1, 1)), (256, (2, 2))),
      ((256, (1, 1)), (512, (2, 2))),
      ((512, (1, 1)), (512, (2, 2))),
  )
  for stage in stages:
    hidden = feature_maps[-1]
    for filters, strides in stage:
      hidden = conv_layer(filters, strides, hidden)
    feature_maps.append(hidden)

  # Classification head: narrow to 8 channels, flatten, one sigmoid unit.
  head = conv_layer(8, (2, 2), feature_maps[-1], apply_batchnorm=False)
  head = tf.keras.layers.Flatten()(head)
  score = tf.keras.layers.Dense(1, activation='sigmoid')(head)

  return tf.keras.Model(inputs=[image_in], outputs=[score, *feature_maps])


Discriminator Loss

In [None]:
def get_discriminator_loss(percep_loss, real_output, fake_output):
  """Combine the adversarial term with a hinged perceptual term.

  Args:
    percep_loss: Perceptual loss between truth and generated feature maps.
    real_output: Discriminator score for the truth images.
    fake_output: Discriminator score for the generated images.

  Returns:
    -THETA * (real_loss + fake_loss) + max(MARGIN - percep_loss, 0).
  """
  real_loss = cross_entropy(tf.ones_like(real_output), real_output)
  fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
  # NOTE(review): the cross-entropy sum is negated here, which is unusual for
  # a discriminator objective - confirm this is intended before changing it,
  # since the generator loss uses a matching label convention.
  binary_loss = -1 * THETA * (real_loss + fake_loss)
  # Hinge expressed as a tensor op instead of a Python `if` on a tensor, so
  # the result stays a tensor on both sides of the margin.
  hinged_percep = tf.nn.relu(MARGIN - percep_loss)
  return binary_loss + hinged_percep

Defining the generator

In [None]:
def encoder(filters, prev, batchnorm=True):
  """Downsampling stage: stride-2 3x3 convolution, optional batch norm, LeakyReLU.

  Args:
    filters: Number of convolution filters.
    prev: Keras tensor to downsample.
    batchnorm: Whether to insert BatchNormalization after the convolution.

  Returns:
    The Keras tensor produced by the LeakyReLU activation.
  """
  kernel_init = tf.random_normal_initializer(0., 0.02)
  down = tf.keras.layers.Conv2D(
      filters, (3, 3), strides=(2, 2), kernel_initializer=kernel_init,
      padding='same')(prev)
  if batchnorm:
    down = tf.keras.layers.BatchNormalization()(down)
  return tf.keras.layers.LeakyReLU(alpha=0.2)(down)

def decoder(filters, prev, skip, dropout=False):
  """Upsampling stage with a skip connection.

  Applies a stride-2 4x4 transposed convolution and batch norm (plus 50%
  dropout when requested), concatenates the skip tensor, then applies ReLU to
  the concatenated result.

  Args:
    filters: Number of transposed-convolution filters.
    prev: Keras tensor to upsample.
    skip: Encoder tensor concatenated onto the upsampled features.
    dropout: Whether to apply Dropout(0.5) before the concatenation.

  Returns:
    The Keras tensor produced by the ReLU activation.
  """
  kernel_init = tf.random_normal_initializer(0., 0.02)
  up = tf.keras.layers.Conv2DTranspose(
      filters, (4, 4), strides=(2, 2), padding='same',
      kernel_initializer=kernel_init)(prev)
  up = tf.keras.layers.BatchNormalization()(up)
  if dropout:
    up = tf.keras.layers.Dropout(0.5)(up)
  merged = tf.keras.layers.Concatenate()([up, skip])
  return tf.keras.layers.Activation('relu')(merged)

def create_generator_model():
  """Build the encoder-decoder generator for 512x512 single-channel images.

  Five encoder stages and a bottleneck convolution halve the resolution six
  times; four decoder stages with skip connections and two further transposed
  convolutions bring it back to 512x512.

  Returns:
    A Keras model mapping a [512, 512, 1] image to a tanh-activated
    [512, 512, 1] image.
  """
  kernel_init = tf.random_normal_initializer(0., 0.02)
  image_in = tf.keras.layers.Input(shape=[512, 512, 1])

  # Encoder path; only the first stage skips batch normalization.
  skips = []
  features = image_in
  for depth, filters in enumerate((64, 128, 256, 512, 512)):
    features = encoder(filters, prev=features, batchnorm=depth != 0)
    skips.append(features)

  # Bottleneck.
  features = tf.keras.layers.Conv2D(
      512, (3, 3), strides=(2, 2), padding='same',
      kernel_initializer=kernel_init)(features)
  features = tf.keras.layers.BatchNormalization()(features)
  features = tf.keras.layers.LeakyReLU(alpha=0.2)(features)

  # Decoder path: (filters, skip tensor, dropout). The first encoder output
  # (skips[0]) is not used as a skip connection.
  decoder_plan = (
      (512, skips[4], True),
      (256, skips[3], True),
      (128, skips[2], True),
      (64, skips[1], False),
  )
  for filters, skip, use_dropout in decoder_plan:
    features = decoder(filters, prev=features, skip=skip, dropout=use_dropout)

  # Two final upsampling steps without skip connections.
  features = tf.keras.layers.Conv2DTranspose(
      64, (4, 4), strides=(2, 2), padding='same',
      kernel_initializer=kernel_init)(features)
  features = tf.keras.layers.BatchNormalization()(features)
  features = tf.keras.layers.Activation('relu')(features)

  features = tf.keras.layers.Conv2DTranspose(
      1, (4, 4), strides=(2, 2), padding='same',
      kernel_initializer=kernel_init)(features)
  generated = tf.keras.layers.Activation('tanh')(features)

  return tf.keras.Model(inputs=[image_in], outputs=[generated])

Generator Loss Function

In [None]:
def get_generator_loss(percep_loss, fake_output):
  """Weighted adversarial term plus the perceptual loss.

  Args:
    percep_loss: Perceptual loss between truth and generated feature maps.
    fake_output: Discriminator score for the generated images.

  Returns:
    THETA * cross_entropy(zeros, fake_output) + percep_loss.
  """
  # NOTE(review): the target labels are zeros, not ones - confirm this is the
  # intended convention; it pairs with the negated discriminator loss.
  zero_labels = tf.zeros_like(fake_output)
  adversarial_term = THETA * cross_entropy(zero_labels, fake_output)
  return adversarial_term + percep_loss

Perceptual Loss function

In [None]:
def get_perceptual_loss(tD1, tD2, tD3, tD4, gD1, gD2, gD3, gD4):
  """Weighted sum of mean absolute differences between paired feature maps.

  Args:
    tD1, tD2, tD3, tD4: Discriminator feature maps for the truth images.
    gD1, gD2, gD3, gD4: Discriminator feature maps for the generated images.

  Returns:
    5*L1(D1) + 1.5*L1(D2) + 1.5*L1(D3) + L1(D4), where L1 is the mean
    absolute difference between the truth and generated maps of that level.
  """
  def mean_abs_diff(truth_features, generated_features):
    return tf.reduce_mean(tf.abs(truth_features - generated_features))

  return (5*mean_abs_diff(tD1, gD1)
          + 1.5*mean_abs_diff(tD2, gD2)
          + 1.5*mean_abs_diff(tD3, gD3)
          + mean_abs_diff(tD4, gD4))

# Training

Train Function

In [None]:
def train(truth_dataset, artifact_dataset, epochs, count):
  """Run the training loop over paired truth/artifact batches.

  The module-level generator is saved via saveModel() every 40th epoch
  (including epoch 0), and the first batch of each epoch is plotted.

  Args:
    truth_dataset: Batched dataset of truth images.
    artifact_dataset: Batched dataset of artifact images, in the same order.
    epochs: Number of passes over the datasets.
    count: Starting value of the step counter forwarded to train_step(); it
      is incremented locally after every step.
  """
  for epoch in range(epochs):
    print('Epoch: ', epoch)
    if epoch % 40 == 0:
      saveModel(generator, epoch)
    batches = zip(truth_dataset, artifact_dataset)
    for step, (t_images, a_images) in enumerate(batches):
      print('.', end='')
      # Only the first step of an epoch shows a preview plot.
      train_step(truth_image=t_images, artifact_image=a_images,
                 show=(step == 0), count=count)
      count += 1

Train Step Function

In [None]:
def train_step(artifact_image, truth_image, show, count):
  """Run one optimization step on a batch.

  The generator is updated on every call; the discriminator only when
  `count` is a multiple of 3. Losses are appended to the module-level
  gen_points / disc_points lists.

  Args:
    artifact_image: Batch of artifact images fed to the generator.
    truth_image: Batch of matching truth images.
    show: If true, plot the first generated/artifact/truth triple of the batch.
    count: Global step index; controls how often the discriminator is updated.
  """
  update_discriminator = count % 3 == 0

  # Only the forward pass and the losses need to be recorded by the tapes.
  with tf.GradientTape() as gen_tape, tf.GradientTape() as discriminator_tape:
    generated_images = generator(artifact_image, training=True)
    real_output, tD1, tD2, tD3, tD4 = model_discriminator(truth_image, training=True)
    fake_output, gD1, gD2, gD3, gD4 = model_discriminator(generated_images, training=True)

    percep_loss = get_perceptual_loss(tD1, tD2, tD3, tD4, gD1, gD2, gD3, gD4)
    gen_loss = get_generator_loss(percep_loss, fake_output)
    if update_discriminator:
      discriminator_loss = get_discriminator_loss(percep_loss, real_output, fake_output)

  if show:
    # Index the first sample instead of unpacking exactly four, so the preview
    # works for any batch size (including a smaller final batch); squeeze
    # drops the trailing channel axis for imshow.
    plot_images(tf.squeeze(generated_images[0], axis=-1),
                tf.squeeze(artifact_image[0], axis=-1),
                tf.squeeze(truth_image[0], axis=-1))

  # Compute gradients outside the tape context, and before any weights change.
  gradients_of_gen = gen_tape.gradient(gen_loss, generator.trainable_variables)
  if update_discriminator:
    gradients_of_disc = discriminator_tape.gradient(
        discriminator_loss, model_discriminator.trainable_variables)
    discriminator_optimizer.apply_gradients(
        zip(gradients_of_disc, model_discriminator.trainable_variables))
    disc_points.append(discriminator_loss)

  generator_optimizer.apply_gradients(zip(gradients_of_gen, generator.trainable_variables))
  gen_points.append(gen_loss)
    

Save the Model

In [None]:
def saveModel(generator, epoch):
  name = 'PAN_e' + str(epoch) +'.h5'
  generator.save(name)
  !cp $name "/content/drive/MyDrive/GAN_Data/Models/PAN_test"

Plot the Images

In [None]:
def plot_images(prediction, input, target):
  """Show the artifact, truth and generated images side by side.

  Each image is mapped with x*0.5 + 0.5 before display, which takes values in
  [-1, 1] to [0, 1].

  Args:
    prediction: Generated image.
    input: Artifact image given to the generator.
    target: Ground-truth image.
  """
  plt.figure(figsize=(15, 15))
  panels = (
      ('Artifacted Image', input),
      ('Ground Truth Image', target),
      ('Generated Image', prediction),
  )
  for position, (caption, image) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(caption)
    plt.imshow(image*0.5 + 0.5, cmap='gray')
    plt.axis('off')
  plt.show()

In [None]:
def plot_loss(gen_points, disc_points):
  """Plot the recorded generator and discriminator losses on one figure.

  Each series is plotted against its own index (0..len-1).
  NOTE(review): the two lists are presumably recorded at different rates, so
  equal x positions need not be the same training step - confirm with the
  code that fills them.

  Args:
    gen_points: Sequence of generator loss values.
    disc_points: Sequence of discriminator loss values.
  """
  plt.plot(range(len(gen_points)), gen_points, label='gen')
  plt.plot(range(len(disc_points)), disc_points, label='disc')
  plt.ylabel('Loss')
  plt.xlabel('Iteration')
  plt.title('Loss values over time')
  plt.legend()
  # The original ended with a no-argument plt.plot(), which draws nothing;
  # show() is what actually renders the figure.
  plt.show()

# "Main" Function

In [None]:
# Jitter every image pair, then batch. Batching happens after the transform so
# each element yielded during training has shape [BATCH_SIZE, 512, 512, 1]
# (the final batch may be smaller).
truth_dataset, artifact_dataset = (
    jittered.batch(BATCH_SIZE)
    for jittered in transform_dataset(truth_dataset, artifact_dataset)
)

# Build the discriminator and its optimizer. train_step() reads both through
# these module-level names.
model_discriminator = create_discriminator_model()
discriminator_optimizer = tf.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
print('Created Discriminator!')
# The return value of plot_model is not used here.
tf.keras.utils.plot_model(model_discriminator, show_shapes=True, dpi=64)

# Build the generator and its optimizer with the same Adam settings.
generator = create_generator_model()
generator_optimizer = tf.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)

print('Created Generator!')

# Loss object shared by get_discriminator_loss() and get_generator_loss().
# NOTE(review): from_logits=True tells the loss to treat predictions as raw
# logits, but create_discriminator_model() ends in a sigmoid Dense layer, so
# the scores are already probabilities. Confirm whether from_logits=False was
# intended; changing it rescales both losses, so review it together with the
# sign and label conventions in the two loss functions.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# Loss histories that train_step() appends to.
gen_points = []
disc_points = []
train(truth_dataset, artifact_dataset, EPOCHS, COUNT)

# Save the final generator locally, then copy it to Drive (IPython shell escape).
generator.save('PAN_FIN.h5')
!cp PAN_FIN.h5 "/content/drive/MyDrive/GAN_Data/Models/PAN_test"
