In [None]:
import tensorflow as tf
import numpy as np
from skimage import metrics
from skimage.metrics import structural_similarity as ssim
from skimage import data
from skimage import img_as_float
from google.colab import drive
import cv2
#from PIL import Image
import matplotlib.pyplot as plt

In [None]:
# Mount Google Drive so the dataset and model paths under /content/drive resolve.
# (drive is already imported in the first cell; re-importing it here is harmless.)
from google.colab import drive
drive.mount('/content/drive')

**Loading the Test Dataset**

In [None]:
# Ground-truth images (path) and their artifacted counterparts (path2). The two
# folders are paired positionally when their datasets are zipped in the main loop.
path, path2 = (
    '/content/drive/MyDrive/GAN_Data/testData/truth',
    '/content/drive/MyDrive/GAN_Data/testData/artifacted_10x',
)
# One image pair per step; denormalize() reshapes each batch to a single image.
BATCH_SIZE = 1

#####Preprocessing#####
def load_image(image, channels=1):
  """Read a PNG file and return it as a float32 tensor of raw pixel values.

  Args:
    image: path (string or string tensor) of the PNG file to read.
    channels: number of colour channels to decode to. Defaults to 1 because
      the rest of this notebook assumes single-channel images (crop() crops
      to [..., 512, 512, 1] and denormalize() reshapes to a 2-D image).
      Pass 0 to keep whatever channel count is stored in the file.

  Returns:
    A float32 tensor of shape (height, width, channels). Values are in
    [0, 255] because decode_png decodes to uint8 by default; scaling to
    [-1, 1] is left to normalize().
  """
  raw_bytes = tf.io.read_file(image)
  # An explicit channel count also gives the tensor a static channel dimension.
  decoded = tf.image.decode_png(raw_bytes, channels=channels)
  return tf.cast(decoded, tf.float32)

def normalize(image):
  """Map raw pixel values from [0, 255] onto [-1, 1].

  0 -> -1.0, 127.5 -> 0.0, 255 -> 1.0. Works on tensors and numpy arrays alike.
  """
  halved = image / 127.5
  return halved - 1

def preprocess(image):
  """Load the PNG at path `image` and scale its pixels to [-1, 1].

  Used as the per-element map function of the filename datasets below.
  """
  return normalize(load_image(image))

def crop(image_a, image_t):
  """Take the same random 512x512 single-channel window from both images.

  The two inputs are concatenated along axis 0 and cropped as one tensor with
  size [2, 512, 512, 1], so both receive the identical random offset.
  NOTE(review): that size only matches when each input is shaped
  [1, H, W, 1] (a batch of one, as BATCH_SIZE = 1 yields); unbatched
  [H, W, 1] inputs would need tf.stack instead. Confirm before using with
  unbatched data.

  Returns:
    The cropped (image_a, image_t) pair, each without the leading axis.
  """
  stacked_pair = tf.concat(values=[image_a, image_t], axis=0)
  patch = tf.image.random_crop(value=stacked_pair, size=[2, 512, 512, 1])
  return patch[0], patch[1]

def resize(image_a, image_t):
  """Resize both images to 542x542 using nearest-neighbour sampling.

  542 leaves a 30-pixel margin for crop() to take a random 512x512 window.
  """
  target = [542, 542]
  nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR
  resized_a, resized_t = (
      tf.image.resize(images=img, size=target, method=nearest)
      for img in (image_a, image_t)
  )
  return resized_a, resized_t

def random_jittering(image_a, image_t):
  """Apply one shared random resize-crop and optional horizontal flip to a pair.

  Both images go through resize() and crop() together, and with probability
  0.5 both are mirrored left-right, so the pair stays spatially aligned.

  Returns:
    The jittered (image_a, image_t) pair.
  """
  resized = resize(image_a, image_t)
  image_a, image_t = crop(*resized)
  coin = tf.random.uniform(shape=[1])
  if coin >= 0.5:
    image_a, image_t = (tf.image.flip_left_right(image_a),
                        tf.image.flip_left_right(image_t))
  return image_a, image_t

def _build_png_dataset(directory):
  """Return a batched dataset of preprocessed PNGs found in `directory`.

  shuffle=False keeps the file order deterministic so that the truth and
  artifact datasets line up element-by-element when they are zipped.
  """
  file_names = tf.data.Dataset.list_files(directory + '/*.png', shuffle=False)
  images = file_names.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
  return images.batch(BATCH_SIZE)

truth_dataset = _build_png_dataset(path)
artifact_dataset = _build_png_dataset(path2)

# zip() silently stops at the shorter dataset, which would hide missing files
# and shift the positional pairing; fail loudly instead.
_n_truth = int(truth_dataset.cardinality())
_n_artifact = int(artifact_dataset.cardinality())
assert _n_truth == _n_artifact, (
    f'truth has {_n_truth} batches but artifacted has {_n_artifact}; '
    'the image pairs would be misaligned')

**Importing the Model**

In [None]:
# Trained generator to evaluate (presumably the epoch-160 checkpoint of the
# PAN_test run, going by the filename).
model_path = '/content/drive/MyDrive/GAN_Data/Models/PAN_test/PAN_e160.h5'
generator = tf.keras.models.load_model(filepath=model_path)

**Save and Download Images**

In [None]:
def saveImages(artifact, truth, num):
  """Write an image pair to the working directory as PNGs and download both.

  Files are named 'Gen_<num>.png' (from `artifact`) and 'Truth_<num>.png'
  (from `truth`). NOTE(review): `artifact` is written under the 'Gen_'
  prefix, so callers presumably pass the generated image here. Confirm this.

  Args:
    artifact: image saved as 'Gen_<num>.png'. Either a uint8 numpy array
      (e.g. from denormalize) or an object with a PIL-style .save(path).
    truth: image saved as 'Truth_<num>.png'; same accepted types.
    num: index embedded in both filenames.

  Raises:
    OSError: if OpenCV fails to write a numpy image.
  """
  # google.colab.files is only needed here (browser download), so import locally.
  from google.colab import files

  for image, prefix in ((artifact, 'Gen_'), (truth, 'Truth_')):
    filename = prefix + str(num) + '.png'
    if hasattr(image, 'save'):
      image.save(filename)
    elif not cv2.imwrite(filename, np.asarray(image)):
      # cv2.imwrite reports failure via its return value, not an exception.
      raise OSError('could not write ' + filename)
    files.download(filename)

**Denormalize Image**

In [None]:
def _to_uint8_image(image, size=(512, 512)):
  """Reshape one [-1, 1] image/tensor to `size` and convert it to uint8 [0, 255]."""
  pixels = np.reshape(np.asarray(image, dtype=np.float32), size)
  pixels = (pixels + 1) / 2 * 255
  # Round (not truncate) and clip before the cast; see denormalize().
  return np.clip(np.rint(pixels), 0, 255).astype(np.uint8)


def denormalize(gen_image, a_images, t_images, size=(512, 512)):
  """Convert generated, artifact and truth batches from [-1, 1] to uint8 images.

  Each input is reshaped to `size` and mapped from [-1, 1] to [0, 255]. It is
  then rounded to the nearest integer and clipped to [0, 255] before the uint8
  cast. For example, a batch holding one single-channel image
  (1, 512, 512, 1) becomes (512, 512).

  Rounding matters because the float32 normalize()/denormalize round trip can
  land a hair below the original integer, and a plain cast would truncate it
  one level down. Clipping keeps generator outputs outside [-1, 1] from
  overflowing the uint8 cast.

  Args:
    gen_image: generator output (tensor or array) with values in [-1, 1].
    a_images: artifacted input batch with values in [-1, 1].
    t_images: ground-truth batch with values in [-1, 1].
    size: target 2-D shape; defaults to (512, 512).

  Returns:
    (gen_out, art_out, tru_out): three uint8 numpy arrays of shape `size`.
  """
  gen_out, art_out, tru_out = (
      _to_uint8_image(x, size) for x in (gen_image, a_images, t_images))
  return gen_out, art_out, tru_out

**Similarity Metrics (MSE, SSIM, PSNR)**

In [None]:
def _similarity_scores(reference, image, data_range=1.0):
  """Return (MSE, SSIM, PSNR) of `image` against `reference` (floats in [0, 1])."""
  mse = metrics.mean_squared_error(reference, image)
  ssim_val = ssim(reference, image, data_range=data_range)
  # PSNR of identical images divides by a zero MSE; the resulting inf is expected.
  with np.errstate(divide='ignore'):
    psnr = metrics.peak_signal_noise_ratio(reference, image, data_range=data_range)
  return mse, ssim_val, psnr


def similarStructure(artifact, truth, gen):
  """Score the artifact and generated images against the truth and plot all three.

  All images pass through img_as_float (uint8 -> float in [0, 1]), so SSIM and
  PSNR use data_range=1.0 explicitly. Without it, older scikit-image assumes a
  range of 2 for floats (inflating SSIM), and newer versions raise an error.
  The truth-vs-truth scores shown under the first panel are the reference
  values (MSE 0, SSIM 1, PSNR inf).

  The generated image's scores are appended to the module-level lists
  ssim_scores, mse_scores and psnr_scores, which must exist before the call.

  Args:
    artifact: artifacted image (2-D array, e.g. uint8 from denormalize).
    truth: ground-truth image of the same shape.
    gen: generated image of the same shape.

  Returns:
    (mse_gen, ssim_gen, psnr_gen) for the generated image.
  """
  truth = img_as_float(truth)
  panels = (
      ('Ground Truth', truth),
      ('Artifact', img_as_float(artifact)),
      ('Generated', img_as_float(gen)),
  )
  scores = [_similarity_scores(truth, image) for _, image in panels]

  mse_gen, ssim_gen, psnr_gen = scores[2]
  ssim_scores.append(ssim_gen)
  mse_scores.append(mse_gen)
  psnr_scores.append(psnr_gen)

  fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(15, 6))
  label = 'MSE: {:.2f}, SSIM: {:.2f}, PSNR: {:.2f}'
  for ax, (title, image), (mse, ssim_val, psnr) in zip(axs.ravel(), panels, scores):
    ax.imshow(image, cmap='gray')
    ax.set_xlabel(label.format(mse, ssim_val, psnr))
    ax.set_title(title)
  plt.show()
  # One figure per image pair: close it so long runs do not accumulate figures.
  plt.close(fig)
  return mse_gen, ssim_gen, psnr_gen


In [None]:
def plot_hist(truth, generated, bin_width=16):
  """Plot side-by-side pixel-intensity histograms of truth and generated pixels.

  Args:
    truth: flat sequence of ground-truth pixel intensities in [0, 255].
    generated: flat sequence of generated pixel intensities in [0, 255].
    bin_width: width of each histogram bin; defaults to 16.
  """
  # Edges run up to and including 256 so the top bin covers 240-255; edges
  # ending at 240 would silently drop the brightest pixels.
  bins = list(range(0, 256 + bin_width, bin_width))

  fig, (ax_truth, ax_gen) = plt.subplots(nrows=1, ncols=2, figsize=(12, 4), sharey=True)
  for ax, values, title in ((ax_truth, truth, 'Ground Truth'),
                            (ax_gen, generated, 'Generated')):
    ax.hist(values, bins=bins)
    ax.set_title(title + ' intensities')
    ax.set_xlabel('Pixel intensity')
    ax.set_ylabel('Pixel count')
  plt.show()
  plt.close(fig)

**Main Loop**

In [None]:
# Per-image metrics; similarStructure() appends the generated image's scores here.
ssim_scores = []
mse_scores = []
psnr_scores = []
# Pixel-intensity buffers for plot_hist(truth_inten, generated_inten);
# the loop below does not fill them.
generated_inten = []
truth_inten = []

for a_images, t_images in zip(artifact_dataset, truth_dataset):
  # Explicit inference mode: any Dropout/BatchNorm layers in the model run as at test time.
  gen_images = generator(a_images, training=False)
  gen, artifact, truth = denormalize(gen_images, a_images, t_images)
  similarStructure(artifact, truth, gen)

# Averaging an empty list would print nan (0/0); fail with a clear message instead.
if not ssim_scores:
  raise ValueError('No image pairs were evaluated; check path and path2.')
print('Average SSIM Score: ' + str(np.mean(ssim_scores)))
print('Average MSE Score: ' + str(np.mean(mse_scores)))
print('Average PSNR Score: ' + str(np.mean(psnr_scores)))
