Importing the necessary libraries

In [None]:
import math
import shutil

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

Input Pipeline

In [None]:
# Report how many GPU devices TensorFlow can see.
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

# Directory of ground-truth PNGs (source of truth_dataset).
path = '/content/drive/MyDrive/GAN_Data/trainingDataNC/truth_small'
# Directory of artifact PNGs (source of artifact_dataset).
path2 = '/content/drive/MyDrive/GAN_Data/trainingDataNC/artifact_10x_small'
# Batch size used for both datasets. crop() reshapes each batch to a leading
# dimension of 1, so the augmentation step relies on this staying 1.
BATCH_SIZE = 1
# Not referenced by any other code visible in this notebook.
THETA = 1
# Not read or updated by any code visible in this notebook.
COUNT = 0
# Number of epochs passed to train().
EPOCHS= 50
# Initial learning rate handed to the SGD optimizer and to log_decay().
LR = 0.01

def load_image(image):
  """Read a PNG file and return its pixels as a float32 tensor.

  Args:
    image: path of the PNG file to read.

  Returns:
    The decoded image cast to tf.float32. No `channels` argument is given to
    the decoder, so the channel count is whatever the file contains.
  """
  raw_bytes = tf.io.read_file(image)
  decoded = tf.image.decode_png(raw_bytes)
  return tf.cast(decoded, tf.float32)

def normalize(image):
  """Linearly rescale pixel values so that 0 maps to -1 and 255 maps to 1.

  Args:
    image: numeric value, array or tensor of pixel intensities.

  Returns:
    `image / 127.5 - 1`, with the same shape as the input.
  """
  scaled = image / 127.5
  return scaled - 1

def preprocess(image):
  """Load the PNG at path `image` and return it normalized.

  Chains load_image() and normalize(); used as the map function of the
  file-listing datasets.
  """
  return normalize(load_image(image))

def _png_batches(directory):
  """Build a batched dataset of preprocessed PNGs found directly in `directory`.

  Files are listed without shuffling so that the truth and artifact datasets
  enumerate their files in the same deterministic order.
  """
  file_names = tf.data.Dataset.list_files(directory + '/*.png', shuffle=False)
  images = file_names.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
  return images.batch(BATCH_SIZE)

truth_dataset = _png_batches(path)

artifact_dataset = _png_batches(path2)

In [None]:
def crop(image_a, image_t):
  """Cut the same random 512x512 window out of an artifact/truth pair.

  Both inputs are reshaped to [1, 542, 542, 1] and stacked along axis 0, so a
  single random_crop call applies one spatial offset to both images.

  Args:
    image_a: artifact image holding 542*542 values.
    image_t: truth image holding 542*542 values.

  Returns:
    (cropped_artifact, cropped_truth), each of shape [512, 512, 1].
  """
  stacked = tf.concat(
      [tf.reshape(img, [1, 542, 542, 1]) for img in (image_a, image_t)],
      axis=0)
  cropped_pair = tf.image.random_crop(stacked, size=[2, 512, 512, 1])
  return cropped_pair[0], cropped_pair[1]

def resize(image_a, image_t):
  """Resize both images to 542x542 using nearest-neighbour interpolation.

  Args:
    image_a: artifact image (or batch of images).
    image_t: truth image (or batch of images).

  Returns:
    (resized_artifact, resized_truth).
  """
  target_size = [542, 542]
  nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR
  resized_a = tf.image.resize(images=image_a, size=target_size, method=nearest)
  resized_t = tf.image.resize(images=image_t, size=target_size, method=nearest)
  return resized_a, resized_t

def random_jittering(image_a, image_t):
  """Augment an artifact/truth pair with identical random jitter.

  The pair is resized to 542x542, randomly cropped to 512x512 with a shared
  window, and then both images are mirrored left-right when a uniform draw
  is at least 0.5.

  Args:
    image_a: artifact image.
    image_t: matching truth image.

  Returns:
    (augmented_artifact, augmented_truth).
  """
  resized_a, resized_t = resize(image_a, image_t)
  image_a, image_t = crop(resized_a, resized_t)
  # One shape-[1] draw decides the flip for both images so they stay aligned;
  # the Python `if` consumes the comparison result as a truth value.
  should_flip = tf.random.uniform(shape=[1]) >= 0.5
  if should_flip:
    image_a, image_t = [tf.image.flip_left_right(img) for img in (image_a, image_t)]
  return image_a, image_t

In [None]:
def transform_dataset(truth_dataset, artifact_dataset):
  """Apply random_jittering() to every truth/artifact pair and rebuild the datasets.

  The two datasets are iterated in lockstep, every augmented image is kept in
  a Python list, and new datasets are created from those lists. The
  augmentation is therefore drawn once, when this function runs.

  Args:
    truth_dataset: iterable of truth image batches.
    artifact_dataset: iterable of artifact image batches, in the same order.

  Returns:
    (truth_dataset, artifact_dataset) built from the augmented images.
  """
  jittered_truth, jittered_artifact = [], []
  for truth_batch, artifact_batch in zip(truth_dataset, artifact_dataset):
    # random_jittering takes and returns the pair in (artifact, truth) order.
    artifact_out, truth_out = random_jittering(artifact_batch, truth_batch)
    jittered_truth.append(truth_out)
    jittered_artifact.append(artifact_out)
  return (tf.data.Dataset.from_tensor_slices(jittered_truth),
          tf.data.Dataset.from_tensor_slices(jittered_artifact))

Creating the ConvNet

In [None]:
def down(filters, prev):
  """Apply a 3x3 same-padded convolution, batch normalization and ReLU to `prev`.

  Args:
    filters: number of output channels of the convolution.
    prev: input Keras tensor.

  Returns:
    The Keras tensor produced by Conv2D -> BatchNormalization -> ReLU.
  """
  conv = tf.keras.layers.Conv2D(
      filters, (3,3), strides=(1,1), padding='same',
      kernel_initializer=tf.keras.initializers.HeNormal())
  normalized = tf.keras.layers.BatchNormalization()(conv(prev))
  return tf.keras.layers.ReLU()(normalized)

In [None]:
def up(filtersFirst,filtersSec, prev, skip):
  """Merge a decoder tensor with its skip connection and refine it with two conv blocks.

  Args:
    filtersFirst: output channels of the first 3x3 convolution.
    filtersSec: output channels of the second 3x3 convolution.
    prev: upsampled decoder tensor.
    skip: encoder tensor to concatenate with `prev`.

  Returns:
    The Keras tensor after Concatenate followed by two
    Conv2D -> BatchNormalization -> ReLU stages.
  """
  he_init = tf.keras.initializers.HeNormal()  # shared by both convolutions
  features = tf.keras.layers.Concatenate()([prev, skip])
  for n_filters in (filtersFirst, filtersSec):
    features = tf.keras.layers.Conv2D(
        n_filters, (3,3), strides=(1,1), padding='same',
        kernel_initializer=he_init)(features)
    features = tf.keras.layers.BatchNormalization()(features)
    features = tf.keras.layers.ReLU()(features)
  return features

In [None]:
def create_convNet_model():
  """Build the encoder-decoder network with skip connections and a residual output.

  Layout, for a [512, 512, 1] input:
    * four encoder stages of down() blocks, each ending in a 2x2 max-pool; the
      tensor just before each pool is kept as a skip connection;
    * three down() blocks at the lowest resolution;
    * four decoder stages, each a stride-2 Conv2DTranspose followed by up()
      with the matching skip connection (deepest first);
    * a 1x1 convolution to one channel, added to the network input and passed
      through tanh.

  Returns:
    A tf.keras.Model mapping the input image to the output image.
  """
  he_init = tf.keras.initializers.HeNormal()
  inputs = tf.keras.layers.Input(shape=[512, 512, 1])

  # Filters of the down() blocks in each encoder stage; the last block of a
  # stage produces that stage's skip connection.
  encoder_plan = ((1, 64, 64), (64, 128), (128, 256), (256, 512))
  # (filtersFirst, filtersSec) for up(); the transposed convolution that
  # precedes each up() uses filtersFirst as its filter count.
  decoder_plan = ((1024, 512), (512, 256), (256, 128), (128, 64))

  x = inputs
  skips = []
  for stage_filters in encoder_plan:
    for n_filters in stage_filters:
      x = down(n_filters, x)
    skips.append(x)
    x = tf.keras.layers.MaxPooling2D(pool_size=(2,2), strides=(2,2))(x)

  for n_filters in (512, 1024, 1024):
    x = down(n_filters, x)

  for (first_filters, second_filters), skip in zip(decoder_plan, reversed(skips)):
    x = tf.keras.layers.Conv2DTranspose(
        first_filters, (3,3), strides=(2,2), padding='same',
        kernel_initializer=he_init)(x)
    x = up(first_filters, second_filters, x, skip)

  residual = tf.keras.layers.Conv2D(
      1, (1,1), strides=(1,1), padding='same', kernel_initializer=he_init)(x)
  summed = tf.keras.layers.Add()([inputs, residual])
  output = tf.keras.layers.Activation('tanh')(summed)

  return tf.keras.Model(inputs=[inputs], outputs=[output])


In [None]:
def get_loss(generated_image, truth_image):
  """Return the mean squared error between the generated and truth images.

  Args:
    generated_image: model output.
    truth_image: target image with a shape compatible with `generated_image`.

  Returns:
    Scalar tensor holding the mean of the element-wise squared difference.
  """
  difference = generated_image - truth_image
  return tf.reduce_mean(difference ** 2)

In [None]:
def log_decay(rate, epochs):
  """Return the learning rate for epoch index `epochs` on a logarithmic decay.

  Epoch 0 keeps the supplied `rate`. For any other epoch the rate is
  `-0.1 * log10(a)`, where `a` changes by RATE per epoch and equals CHANGE
  when `epochs == EPOCHS`. With this notebook's CHANGE = .977 and
  RATE = (.977 - .795) / EPOCHS, `a` rises from about .795 to .977, so the
  rate falls from roughly 0.01 to roughly 0.001. The result is floored at 0.001.

  The epoch index has to enter the formula: using the fixed value
  CHANGE - RATE yields the same rate for every epoch.

  Args:
    rate: learning rate returned unchanged for epoch 0.
    epochs: zero-based index of the current epoch.

  Returns:
    The learning rate to use for that epoch, never below 0.001.

  Raises:
    ValueError: if the log argument is not positive (only possible for
      epoch indices far outside 0..EPOCHS).
  """
  if epochs == 0:
    return rate
  # Linear ramp of the log argument that ends at CHANGE on the final epoch.
  log_argument = CHANGE - RATE * (EPOCHS - epochs)
  return max(-0.1 * math.log10(log_argument), 0.001)

Training

In [None]:
def train(truth_dataset, artifact_dataset, epochs):
  """Train the global `model` for `epochs` passes over the paired datasets.

  At the start of every epoch the learning rate of the global `optimizer`
  (the one train_step() applies gradients with) is set from
  log_decay(LR, epoch). Every 10th epoch the model is handed to saveModel().
  The first batch of each epoch is plotted by train_step().

  Args:
    truth_dataset: batched dataset of ground-truth images.
    artifact_dataset: batched dataset of artifact images, in matching order.
    epochs: number of epochs to run.
  """
  for epoch in range(epochs):  # outer loop: one iteration per epoch
    print('Epoch: ', epoch)
    # Update the shared optimizer in place. A new optimizer object created
    # here would never be seen by train_step(), and would also discard the
    # accumulated momentum. The current epoch index drives the schedule.
    optimizer.learning_rate = log_decay(LR, epoch)
    if epoch % 10 == 0:  # checkpoint every 10 epochs
      saveModel(model, epoch)
    show = True  # plot only the first batch of the epoch
    # Inner loop: iterating the batched tf datasets yields one batch per step.
    for t_images, a_images in zip(truth_dataset, artifact_dataset):
      print('.', end='')
      train_step(truth_image=t_images, artifact_image=a_images, show=show)
      show = False

In [None]:
def train_step(artifact_image, truth_image, show):
  """Run one optimization step of the global `model` on a single batch.

  Args:
    artifact_image: batch of input images, shaped [batch, height, width, 1].
    truth_image: batch of target images with the same shape.
    show: when true, plot the first sample of the batch (input, target and
      prediction) via plot_images().
  """
  # Only the forward pass and the loss need to be recorded on the tape.
  with tf.GradientTape() as Conv_tape:
    generated_images = model(artifact_image, training=True)  # forward pass
    loss = get_loss(generated_images, truth_image)

  if show:
    # Plot the first sample as a 2-D image; slicing (rather than reshaping the
    # whole batch to (512, 512)) also works when the batch holds several images.
    plot_images(generated_images[0, :, :, 0],
                artifact_image[0, :, :, 0],
                truth_image[0, :, :, 0])

  # Differentiate outside the tape context so the backward pass is not itself recorded.
  gradients_of_gen = Conv_tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients_of_gen, model.trainable_variables))

In [None]:
def plot_images(prediction, input, target):
  """Show the input, target and prediction side by side in grayscale.

  Each image is mapped through `x * 0.5 + 0.5` before display, which undoes
  the [-1, 1] scaling applied by normalize().

  Args:
    prediction: 2-D generated image.
    input: 2-D artifact image fed to the model.
    target: 2-D ground-truth image.
  """
  panels = (
      ('Artifacted Image', input),
      ('Ground Truth Image', target),
      ('Generated Image', prediction),
  )
  plt.figure(figsize=(15,15))
  for position, (caption, image) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(caption)
    plt.imshow(image*0.5 +0.5, cmap='gray')
    plt.axis('off')
  plt.show()

In [None]:
def saveModel(generator, epoch, dest_dir="/content/drive/MyDrive/GAN_Data/Models"):
  """Save `generator` as 'FBPConv_e<epoch>.h5' and copy the file to `dest_dir`.

  Nothing is written for epoch 0. The model passed in as `generator` is the
  one that gets saved, rather than whatever the global `model` happens to be.

  Args:
    generator: object exposing `save(path)`, e.g. a Keras model.
    epoch: epoch index used in the file name; 0 skips saving.
    dest_dir: directory that receives a copy of the saved file.
  """
  if epoch == 0:
    return
  name = 'FBPConv_e' + str(epoch) +'.h5'
  generator.save(name)
  # Pure-Python copy instead of the `!cp` shell magic, so this also runs outside IPython.
  shutil.copy(name, dest_dir)

In [None]:

# Apply the random jitter (resize, crop, optional flip) to every pair once, up
# front, and rebuild both datasets from the augmented images.
truth_dataset, artifact_dataset = transform_dataset(truth_dataset, artifact_dataset)
# Batch the rebuilt datasets again before training.
truth_dataset = truth_dataset.batch(BATCH_SIZE)
artifact_dataset = artifact_dataset.batch(BATCH_SIZE)

# `model` is a module-level global: train() and train_step() refer to it by name.
model = create_convNet_model()
tf.keras.utils.plot_model(model, show_shapes=True, dpi=64)
# Constants read by log_decay(). Note -0.1*log10(.795) is about 0.01 and
# -0.1*log10(.977) is about 0.001, so RATE is the per-epoch step between them.
RATE = (.977-.795)/EPOCHS
CHANGE = .977
# Alternative built-in schedule, left disabled:
#lr_scheduler = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=0.01, decay_steps=1, decay_rate=0.743)
# `optimizer` is a module-level global: train_step() applies gradients with it.
optimizer = tf.keras.optimizers.SGD(learning_rate=LR, momentum=0.99, clipvalue=0.1)
train(truth_dataset, artifact_dataset, EPOCHS)

# Save the final model locally, then copy the file to the Drive models folder.
model.save('FBPConv.h5')
!cp FBPConv.h5 "/content/drive/MyDrive/GAN_Data/Models"