<a href="https://colab.research.google.com/github/vebrahimi1990/GAN-in-Colab/blob/main/Denoising_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Denoising Super-Resolution Images with GAN**

This is a short tutorial on how to use a GAN (Generative Adversarial Network) for image denoising. In this tutorial you will learn how to upload your dataset to Google Drive, use it to train a GAN in Google Colab, and then test the trained network on your test dataset.

To use this notebook, you first need a Google Drive account. This is very easy if you already have a Google account: just click on this link (https://drive.google.com/drive/my-drive) and sign in with your Google username and password.

After creating your Google Drive, create a folder and copy your dataset into it. This is a crucial step for using our Colab Notebook. For denoising you need pairs of images: in each pair, one image is the noisy (low SNR) image and the other is the ground-truth (high SNR) image of the same area (field of view, FOV). Please make sure that the two images of each pair have the same file name, as shown here:

    (Folder) dataset
      (Folder) train data
        (Folder) Low SNR: 1.tif, 2.tif, 3.tif, …
        (Folder) GT: 1.tif, 2.tif, 3.tif, …
      (Folder) test data
        (Folder) Low SNR: 1.tif, 2.tif, 3.tif, …
        (Folder) GT: 1.tif, 2.tif, 3.tif, …

For instance, 1.tif in the Low SNR folder and 1.tif in the GT folder should show the same area. This naming makes it easy to import the data into the Colab Notebook. More training data generally helps a deep learning network, but there is a trade-off between the size of the dataset and the training time. We suggest that the training set contains at least 15 images, each showing an area of about 20 µm × 20 µm.
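Because each noisy image is matched to its ground truth only by file name, it can help to check the pairing before uploading. The sketch below is a hypothetical helper, `find_unpaired`, that is not part of this notebook; it assumes the sub-folder names `Low SNR` and `GT` from the layout above and lists the files that have no partner.

```python
import os
import tempfile

def find_unpaired(folder, low_name="Low SNR", gt_name="GT"):
    """Return (files only in the low-SNR folder, files only in the GT folder).

    Hypothetical helper; the default sub-folder names follow the layout above.
    """
    low_files = set(os.listdir(os.path.join(folder, low_name)))
    gt_files = set(os.listdir(os.path.join(folder, gt_name)))
    return sorted(low_files - gt_files), sorted(gt_files - low_files)

# Demo on a throw-away folder: 1.tif and 2.tif are paired, 3.tif has no ground truth
with tempfile.TemporaryDirectory() as root:
    for sub, names in [("Low SNR", ["1.tif", "2.tif", "3.tif"]), ("GT", ["1.tif", "2.tif"])]:
        os.makedirs(os.path.join(root, sub))
        for name in names:
            open(os.path.join(root, sub, name), "wb").close()
    print(find_unpaired(root))  # (['3.tif'], [])
```

On your own data you would call, for example, `find_unpaired("/content/drive/My Drive/dataset/train data")` and expect two empty lists.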

After preparing the folder containing all the training and test data, simply drag and drop it into your Google Drive.

The next step is to mount your Google Drive in the Colab Notebook. To do that, click on the Folder icon on the left side of the screen and then click the **Mount Drive** button.
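If you prefer code to the **Mount Drive** button, Colab's `google.colab` package provides `drive.mount`. The sketch below wraps it in a hypothetical helper, `mount_google_drive`, that prints a message and returns `False` when the notebook is not running in Colab, so the same cell is harmless elsewhere.

```python
import os

def mount_google_drive(mount_point="/content/drive"):
    """Mount Google Drive in a Colab runtime; return True if the mount point exists afterwards."""
    try:
        from google.colab import drive  # only importable inside Google Colab
    except ImportError:
        print("Not running in Google Colab; Drive was not mounted.")
        return False
    drive.mount(mount_point)  # asks you to authorize access the first time
    return os.path.isdir(mount_point)

mounted = mount_google_drive()
```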

Before you start using our Notebook, there is another crucial step: make sure you have access to a GPU to accelerate training. Training deep learning networks on a CPU can be very slow. To use a GPU, click on the **Runtime** menu at the top of the screen, click on **Change runtime type**, choose **GPU** as your hardware accelerator, click **Save**, and wait a few seconds until a GPU is allocated to you.

At this point you should be able to start using the Notebook. A Colab Notebook consists of multiple blocks called cells, and each cell contains a piece of code. There is a play button at the top left of each cell; pushing it runs the cell. A short description above each cell tells you why you need to run it. Now, start playing with this Notebook and enjoy the power of deep learning for your denoising application.

**Good luck!**


# **Cell #1**

Please run this cell to import all the tools needed to build our deep learning network. These tools are used in the following cells, so they need to be imported first.

In [None]:
#@title Importing the necessary libraries  { form-width: "50%" }
import os
import random
import time

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
# Import all Keras pieces from tf.keras; the old keras.layers.core / .convolutional /
# .pooling / .merge sub-modules no longer exist in recent Keras versions
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Input, BatchNormalization, Activation, Dense, Dropout,
                                     Flatten, LeakyReLU, Conv2D, Conv2DTranspose,
                                     MaxPooling2D, concatenate, add)
from tensorflow.keras.utils import load_img, img_to_array
from IPython import display
# tifffile.imsave is deprecated in favour of imwrite; keep the imsave name used below
from tifffile import imwrite as imsave

# **Cell #2**
This is a quick test to make sure we have access to a GPU. If you get an error, go back to **Runtime** > **Change runtime type** and make sure you chose **GPU** as your hardware accelerator.

In [None]:
#@title Make sure you have access to a GPU { form-width: "50%" }
assert len(tf.config.list_physical_devices('GPU')) > 0, "No GPU found: choose GPU under Runtime > Change runtime type"

# **Cell #3**

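Now it is time to import your training set into the Colab Notebook. To do so, paste the path of your training folder into the **image_dr** field, for example:

    /content/drive/My Drive/dataset/train data

The cell expects the ground-truth and low-SNR images in two sub-folders of this path (named `Average` and `1frame` in the example dataset); adjust the folder names in the code to match your own.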

 After running the cell you should be able to see the names of your images in your training set. 

There are a few parameters that need to be set in this cell. First, set **image_size**, the width (and height) of your original images in pixels. For instance, 1024 means that your images are 1024 × 1024 pixels; all images are assumed to be square and of the same size.

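The next parameter is **patch_size**. When we train a deep learning network for denoising, we cut our large training images into smaller patches. This increases the number of training samples and keeps each batch small enough to fit in GPU memory. A patch size of 64 or 128 is reasonable, but you can make it bigger or smaller. Because the generator halves the patch size **num_layers** times (set in Cell #5) and then upsamples it back, the patch size should be divisible by 2^num_layers; a power of 2 satisfies this. It should also not be too small for the VGG19 network used in the loss, which Keras documents as needing inputs of at least 32 × 32 pixels.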

The last parameter in this cell is **n_patches_per_image**, the number of patches to create from each image. The patches are taken on a regular grid, so this number must be a perfect square such as 4, 16, 64 or 256; choose it according to the size of your original images.
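The patch grid can be illustrated with plain NumPy. The sketch below uses hypothetical helper names (`patch_corners`, `extract_patches`) and assumes square images: it places the top-left corners of an n × n grid evenly between 0 and image_size - patch_size, with n = sqrt(n_patches_per_image), and normalizes every patch to a maximum of 1. With image_size = 1024, patch_size = 128 and 64 patches, the corners fall at multiples of 128, so the patches tile the image without overlap; with fewer or smaller patches they are spread out with gaps between them.

```python
import numpy as np

def patch_corners(image_size, patch_size, n_patches_per_image):
    """Top-left coordinates of an n x n grid of patches spread evenly over the image."""
    n = int(round(np.sqrt(n_patches_per_image)))
    if n * n != n_patches_per_image:
        raise ValueError("n_patches_per_image must be a perfect square (4, 9, 16, ...)")
    return np.floor(np.linspace(0, image_size - patch_size, n)).astype(int)

def extract_patches(image, patch_size, n_patches_per_image):
    """Cut a square (H, W) image into grid patches, each normalized to a maximum of 1."""
    corners = patch_corners(image.shape[0], patch_size, n_patches_per_image)
    patches = []
    for r in corners:
        for c in corners:
            patch = image[r:r + patch_size, c:c + patch_size].astype(np.float32)
            patches.append(patch / max(patch.max(), 1e-8))  # avoid dividing by zero
    return np.stack(patches)

print(patch_corners(1024, 128, 64))  # [  0 128 256 384 512 640 768 896]
print(extract_patches(np.random.rand(1024, 1024), 64, 16).shape)  # (16, 64, 64)
```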

In [None]:
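#@title Directory to the training set { form-width: "50%" }
image_dr = "/content/drive/My Drive/data/Fast_STED_Confocal_vs_STED/STED" #@param {type:"string"}
image_size = 1024 #@param {type:"integer"}
# Sub-folder names inside image_dr for the ground-truth (high SNR) and noisy (low SNR) images.
# Change them to match your own layout, e.g. "GT" and "Low SNR".
gt_folder = "Average" #@param {type:"string"}
low_folder = "1frame" #@param {type:"string"}

# Sort the file names so that the order is reproducible between runs
image_list = sorted(os.listdir(os.path.join(image_dr, gt_folder)))
print('The image file names are:', *image_list, sep='\n')
M = len(image_list)
GT = np.empty((M,image_size,image_size,1),dtype=np.float32)
low = np.empty((M,image_size,image_size,1),dtype=np.float32)

for i in range(M):
  # A ground-truth image and its low-SNR partner share the same file name
  img_GT = load_img(os.path.join(image_dr, gt_folder, image_list[i]), color_mode="grayscale")
  img_low = load_img(os.path.join(image_dr, low_folder, image_list[i]), color_mode="grayscale")

  GT[i] = img_to_array(img_GT).astype(np.float32)
  low[i] = img_to_array(img_low).astype(np.float32)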

patch_size = 64 #@param {type:"integer"}
n_patches_per_image =  16#@param {type:"integer"}
# Patches are taken on a regular n x n grid, so n_patches_per_image must be a perfect square
n_patches_per_row = int(round(np.sqrt(n_patches_per_image)))
assert n_patches_per_row**2 == n_patches_per_image, "n_patches_per_image must be a perfect square"

# Top-left row/column of each patch, spread evenly from 0 to image_size - patch_size
rr = np.floor(np.linspace(0, image_size - patch_size, n_patches_per_row)).astype(int)
cc = rr

xx = np.empty((M*n_patches_per_image,patch_size,patch_size,1),dtype=np.float32)
yy = np.empty((M*n_patches_per_image,patch_size,patch_size,1),dtype=np.float32)

X = np.empty((4*M*n_patches_per_image,patch_size,patch_size,1),dtype=np.float32)
Y = np.empty((4*M*n_patches_per_image,patch_size,patch_size,1),dtype=np.float32)

count = 0
for i in range(M):
  for j in range(n_patches_per_row):
    for k in range(n_patches_per_row):
      r, c = rr[j], cc[k]
      # Normalize each patch to a maximum of 1 (the small constant guards against all-zero patches)
      xx[count] = low[i, r:r+patch_size, c:c+patch_size, :]
      xx[count] = xx[count]/max(xx[count].max(), 1e-8)
      yy[count] = GT[i, r:r+patch_size, c:c+patch_size, :]
      yy[count] = yy[count]/max(yy[count].max(), 1e-8)
      count+=1

# Data augmentation: store every patch pair together with its vertical, horizontal
# and combined flips, which quadruples the dataset
X[0:count,:,:,:]=xx
X[count:2*count,:,:,:]=np.flip(xx,axis=1)
X[2*count:3*count,:,:,:]=np.flip(xx,axis=2)
X[3*count:4*count,:,:,:]=np.flip(xx,axis=(1,2))

Y[0:count,:,:,:]=yy
Y[count:2*count,:,:,:]=np.flip(yy,axis=1)
Y[2*count:3*count,:,:,:]=np.flip(yy,axis=2)
Y[3*count:4*count,:,:,:]=np.flip(yy,axis=(1,2))

# Shuffle X and Y with the same random permutation so that the pairs stay aligned.
# Note: the flipped copies of a patch are shuffled independently, so some copies
# can land in the training set and others in the validation set.
perm = np.random.permutation(len(X))
XX = X[perm]
YY = Y[perm]

# Split train and valid
ratio = 0.8
M1 = int(np.floor(X.shape[0]*ratio))
X_train = XX[:M1]
Y_train = YY[:M1]
X_valid = XX[M1:]
Y_valid = YY[M1:]

print('The training set shape is:',X_train.shape)
print('The validation set shape is:',X_valid.shape)

# **Cell #4**

Run this cell to display a randomly chosen pair from your dataset: the low-SNR patch on the left and the matching ground-truth patch on the right.

In [None]:
#@title Plot pairs of images in your dataset { form-width: "50%" }
ix = np.random.randint(len(X))  # random patch index in [0, len(X))
cmap = plt.get_cmap('magma')
fig = plt.figure(figsize=(15,15))
fig.add_subplot(1,2,1)
plt.imshow(X[ix].squeeze(), cmap=cmap)
plt.title('Low SNR')
plt.axis('off')

fig.add_subplot(1,2,2)
plt.imshow(Y[ix].squeeze(), cmap=cmap)
plt.title('Ground truth')
plt.axis('off')

# **Cell #5**

Please specify the number of convolutional filters, the number of layers, and the learning rate of your network, or keep the preset values. Increasing the number of filters or layers can improve prediction accuracy, but larger networks need more GPU memory and take longer to train. Finding the optimal learning rate is tricky, so you can start with the preset value: a very small learning rate makes the learning process slow, while a large one can make training unstable.

After running this cell, you should be able to see a plot of your generator model. 
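The way `num_filters` and `num_layers` set the size of the generator can be traced by hand. The sketch below, with the hypothetical name `generator_plan`, mirrors the bookkeeping in `make_generator`: encoder level i works with num_filters·2^i filters before a 2×2 max pooling halves the patch, and the bottleneck block uses twice the last level's width. It also checks that `patch_size` is divisible by 2^num_layers, which the decoder needs so that each upsampled feature map matches its skip connection.

```python
def generator_plan(patch_size, num_filters, num_layers):
    """Return [(spatial size, filters), ...] for each encoder level plus the bottleneck."""
    if patch_size % (2 ** num_layers) != 0:
        raise ValueError("patch_size must be divisible by 2**num_layers")
    plan = []
    size = patch_size
    for i in range(num_layers):
        plan.append((size, num_filters * 2 ** i))  # residual block before pooling
        size //= 2                                 # MaxPooling2D((2, 2))
    plan.append((size, 2 * num_filters * 2 ** (num_layers - 1)))  # bottleneck block
    return plan

print(generator_plan(64, 16, 2))  # [(64, 16), (32, 32), (16, 64)]
```

With the preset values the deepest feature maps are 16 × 16 with 64 filters; each extra layer doubles the bottleneck width, one reason deeper networks need more GPU memory.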

In [None]:
#@title Define the hyperparameters of your model { form-width: "50%" }
num_filters = 16 #@param {type:"integer"}
num_layers =  2#@param {type:"integer"}
learning_rate =  5e-5#@param {type:"number"}




# Pretrained VGG19 feature extractor for the perceptual loss
vgg = tf.keras.applications.VGG19(include_top=False, weights='imagenet',input_tensor=Input(shape=(patch_size, patch_size, 3)))
# One sub-model per VGG19 layer whose activations are compared in percep_loss:
# the first convolution of blocks 1-5 (indices 1, 4, 7, 12, 17) and the second of blocks 1-3 (2, 5, 8)
vgg_layer_ids = [1, 2, 4, 5, 7, 8, 12, 17]
inter_vgg = [Model(inputs=vgg.input, outputs=vgg.layers[i].output) for i in vgg_layer_ids]

def nmse_loss(pred,gt):
  mse = tf.keras.metrics.mean_squared_error(pred,gt)
  mse = tf.math.reduce_sum(mse,axis=(1,2))
  norm = tf.norm(gt,axis=(1,2))
  norm = tf.squeeze(norm)
  norm = tf.pow(norm,2)
  norm = tf.math.reduce_sum(norm)
  nmse = tf.math.divide(mse,norm)
  nmse = tf.math.reduce_mean(nmse)
  return nmse

def fft_loss(pred,gt):
  pred = tf.transpose(pred, perm=[0, 3, 1, 2])
  gt   = tf.transpose(gt, perm=[0, 3, 1, 2])

  # Shift only the two frequency axes, not the batch and channel axes
  pred_fft = tf.signal.fftshift(tf.signal.rfft2d(pred), axes=(-2, -1))
  gt_fft   = tf.signal.fftshift(tf.signal.rfft2d(gt), axes=(-2, -1))

  pred_fft = tf.transpose(pred_fft, perm=[0, 2, 3, 1])
  gt_fft   = tf.transpose(gt_fft, perm=[0, 2, 3, 1])

  ft_loss = nmse_loss(pred_fft,gt_fft)
  ft_loss = tf.cast(ft_loss,tf.float32)
  return ft_loss


def percep_loss(pred,gt):
  ploss = 0
  nmse = nmse_loss(pred,gt)
  ft_loss = fft_loss(pred,gt)
  ssim_loss = 1.0-tf.math.reduce_mean(tf.image.ssim(pred,gt,max_val=1))
  pred = tf.image.grayscale_to_rgb(pred)
  gt   = tf.image.grayscale_to_rgb(gt)
  for i in range(8):
    vgg_pred = inter_vgg[i](pred)
    vgg_gt = inter_vgg[i](gt)
    ploss = ploss+nmse_loss(vgg_pred,vgg_gt)
  ploss = 1000*(ploss)+0.1*ssim_loss+5*nmse+20*ft_loss
  return ploss

def discriminator_loss(real_output, fake_output):
    real_loss = -tf.math.log(real_output)
    fake_loss = -tf.math.log(1.0-fake_output+tf.math.sign(fake_output-0.5)*1e-8)
    total_loss = tf.math.reduce_mean(real_loss + fake_loss)
    return total_loss

def generator_validation_loss(pred,gt):
  ssim_loss = 1.0-tf.math.reduce_mean(tf.image.ssim(pred,gt,max_val=1))
  nmse = nmse_loss(pred,gt)
  valid_loss = ssim_loss+1e-3*nmse
  return valid_loss


def dis_conv_block(inpt,num_filters=64,kernel_shape=(3,3),strides=(1,1)):
  x = Conv2D(num_filters, kernel_shape, padding="same",strides=strides)(inpt)
  x = BatchNormalization()(x)
  x = LeakyReLU()(x)
  return x

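# Note: in the blocks below, BatchNormalization is called with training=False, so it always runs
# in inference mode: it normalizes with its moving mean/variance, which are never updated
# (its scale and offset are still trained).
def res_conv_block(inputs,num_filters=32,kernel_shape=(3,3)):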
  x = Conv2D(num_filters, kernel_shape, padding="same")(inputs)
  x = BatchNormalization()(x,training=False)
  x = LeakyReLU()(x)
  x = Conv2D(num_filters, kernel_shape, padding="same")(x)
  x = BatchNormalization()(x,training=False)
  x = add([x,inputs])
  return x

def conv_block(inputs,num_filters=32,kernel_shape=(3,3)):
  x = Conv2D(num_filters, kernel_shape, padding="same")(inputs)
  x = BatchNormalization()(x,training=False)
  x = LeakyReLU()(x)
  x = Conv2D(num_filters, kernel_shape, padding="same")(x)
  x = BatchNormalization()(x,training=False)
  return x

def make_discriminator(input_gen,input_gt,num_filters=num_filters,kernel_shape=(3,3)):
  x = concatenate([input_gen,input_gt])
  x = Conv2D(num_filters, kernel_shape, padding="same")(x)
  x = LeakyReLU()(x)
  x = dis_conv_block(x,num_filters,kernel_shape,strides=(2,2))
  x = dis_conv_block(x,2*num_filters,kernel_shape,strides=(1,1))
  x = dis_conv_block(x,2*num_filters,kernel_shape,strides=(2,2))
  x = dis_conv_block(x,4*num_filters,kernel_shape,strides=(1,1))
  x = dis_conv_block(x,4*num_filters,kernel_shape,strides=(2,2))
  x = dis_conv_block(x,8*num_filters,kernel_shape,strides=(1,1))
  x = dis_conv_block(x,8*num_filters,kernel_shape,strides=(2,2))
  x = Flatten()(x)
  x = Dense(256)(x)
  x = LeakyReLU()(x)
  x = Dense(1)(x)
  x = Activation("sigmoid")(x)
  model = Model(inputs=[input_gen,input_gt], outputs=[x])
  return model

def make_generator(inputs,num_filters=num_filters,num_layers=num_layers,kernel_shape=(3,3),dropout=0.4):
  filters = []
  skip_x = []
  for i in range(num_layers):
    filters.append((2**i)*num_filters)
  y = Conv2D(num_filters,kernel_shape, padding = 'same')(inputs)
  y = LeakyReLU()(y)
  x = y
  for f in filters:
    x = res_conv_block(x, f,kernel_shape)
    skip_x.append(x)
    x = MaxPooling2D((2, 2))(x)
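    # Dropout called with training=False stays inactive even during training (same for the decoder below)
    x = Dropout(dropout)(x,training=False)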
    x = Conv2D(2*f,kernel_shape, padding = 'same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)

  x = res_conv_block(x, 2*filters[-1],kernel_shape)
  filters.reverse()
  skip_x.reverse()

  for i, f in enumerate(filters):
    x = Conv2DTranspose(f,kernel_shape, strides = (2, 2), padding = 'same')(x)
    xs = skip_x[i]
    x = concatenate([x, xs])
    x = LeakyReLU()(x)
    x = Dropout(dropout)(x,training=False)
    x = conv_block(x, f,kernel_shape)

  x = Conv2D(num_filters, kernel_shape, padding="same")(x)
  x = LeakyReLU()(x)
  x = add([x,y])
  x = Conv2D(num_filters, kernel_shape, padding="same")(x)
  x = LeakyReLU()(x)
  x = Conv2D(1, kernel_shape, padding="same")(x)
  model = Model(inputs=[inputs], outputs=[x])
  return model

def make_gan(generator,discriminator,inputs):
  gen_out = generator(inputs)
  # Same input order as when training the discriminator: [low-SNR image, candidate image]
  dis_out = discriminator([inputs,gen_out])
  model = Model(inputs=inputs,outputs=[dis_out,gen_out])
  return model

gan_lr = learning_rate
dis_opt = keras.optimizers.Adam(learning_rate=gan_lr)
gan_opt = keras.optimizers.Adam(learning_rate=gan_lr)

discrim_input = Input((patch_size,patch_size,1))
discrim_input1 = Input((patch_size,patch_size,1))
gen_input = Input((patch_size,patch_size,1))
gan_input = Input((patch_size,patch_size,1))

dis_model = make_discriminator(discrim_input,discrim_input1,num_filters=64)
gen_model = make_generator(gen_input,num_filters=num_filters,num_layers=num_layers,kernel_shape=(3,3),dropout=0.3)
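gan_model = make_gan(gen_model,dis_model,gan_input)

dis_model.compile(optimizer=dis_opt,loss=['binary_crossentropy'], loss_weights=[1])
# Freeze the discriminator inside the combined model. With tf.keras (Keras 2) the trainable
# state is captured at compile time, so dis_model still learns in its own train_on_batch,
# while gan_model.train_on_batch only updates the generator.
dis_model.trainable = False
gan_model.compile(optimizer=gan_opt,loss=['binary_crossentropy',percep_loss],loss_weights=[0.01,1])
gen_model.compile(optimizer=gan_opt,loss=percep_loss)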

#gan_model.summary()
#gen_model.summary()
#dis_model.summary()
tf.keras.utils.plot_model(gen_model,show_shapes=True, show_layer_names=True, dpi=40)

# **Cell #6**

In this cell, you can train your network on your training dataset. First, specify two file paths in your Google Drive: one for saving the model weights (extension **.h5**) and one for saving the loss values (extension **.csv**). Then specify the number of epochs and the batch size. The number of epochs is how many times training runs through (roughly) the whole training set. The batch size is how many patches are used in each training step; each epoch consists of (number of training patches) / (batch size) steps, rounded down. For instance, if you have 1024 patches in your training set and choose a batch size of 64, your network goes through 16 steps in each epoch.
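The step count is simple integer arithmetic. The sketch below (hypothetical helper names) reproduces it, together with the number of training patches that Cell #3 produces: each patch is stored with three flipped copies, and 80 % of the result is kept for training. With the suggested minimum of 15 images and the preset 16 patches per image, that gives 768 training patches, or 24 steps per epoch at a batch size of 32. Because each batch is drawn at random, an epoch is about one pass over the data rather than an exact one.

```python
import math

def n_training_patches(n_images, n_patches_per_image, ratio=0.8):
    """Patches left for training after the 4-fold flip augmentation and the split."""
    return math.floor(4 * n_images * n_patches_per_image * ratio)

def steps_per_epoch(n_train, batch_size):
    """Updates per epoch: training-set size divided by batch size, rounded down."""
    return n_train // batch_size

print(steps_per_epoch(1024, 64))                        # 16
print(n_training_patches(15, 16))                       # 768
print(steps_per_epoch(n_training_patches(15, 16), 32))  # 24
```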

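Press the play button to start training. After each epoch, the network is evaluated on the validation set, and whenever the validation loss reaches a new minimum, the model weights are saved. If the validation loss does not improve for about 10 consecutive epochs, the learning rate is halved (and halved again after every further 10 epochs without improvement), and training stops early after 100 epochs without improvement.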
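The bookkeeping in this training loop (save the best weights, halve the learning rate on a plateau, stop after a long one) is easier to follow without Keras. The sketch below replays a list of validation losses through the same rules as the loop in this cell, with the same constants (10 epochs for halving, 100 epochs for stopping, and a 1e-7 floor on the learning rate); the function name `replay_schedule` is hypothetical and nothing is trained.

```python
def replay_schedule(val_losses, lr=5e-5, lr_patience=10, stop_patience=100, min_lr=1e-7):
    """Replay validation losses through the checkpoint / learning-rate / stopping rules.

    Returns (epochs at which weights would be saved, learning rate used in each epoch).
    """
    best = float("inf")
    count = 0                       # epochs since the last improvement
    saved, lrs = [], []
    for epoch, val in enumerate(val_losses):
        lrs.append(lr)              # rate used to train this epoch
        # Same test as the notebook: halve when count + 1 is a multiple of lr_patience
        if (count + 1) % lr_patience == 0 and lr > min_lr:
            lr *= 0.5
        if val <= best:             # new best validation loss: weights would be saved
            best = val
            saved.append(epoch)
            count = 0
        else:
            count += 1
        if count == stop_patience:  # early stopping
            break
    return saved, lrs

saved, lrs = replay_schedule([1.0, 0.9] + [1.5] * 30)
print(saved)     # [0, 1]
print(len(lrs))  # 32
```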

In [None]:
#@title Training loop parameters { form-width: "50%" }
model_save_directory = "/content/drive/My Drive/model-GAN.h5" #@param {type:"string"}
loss_save_directory = "/content/drive/My Drive/model-GAN-loss.csv" #@param {type:"string"}
n_epochs = 100 #@param {type:"integer"}
n_batch =  32#@param {type:"integer"}
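def generate_real_samples(low,gt,batch_size):
  # Draw a random batch of (low-SNR, ground-truth) pairs from the whole set; label 1 = real
  ix = np.random.randint(0, low.shape[0], batch_size)
  X1, X2 = low[ix],gt[ix]
  y = np.ones((batch_size, 1))
  return [X1, X2], y

def generate_fake_samples(X1,gen_model,batch_size):
  # Denoise the batch with the current generator; label 0 = fake
  X = gen_model.predict(X1, verbose=0)
  y = np.zeros((batch_size, 1))
  return X, y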

def train(dis_model, gen_model, gan_model,low,gt,low_valid,gt_valid, n_epochs=100, n_batch=32):
  n_steps = np.floor(low.shape[0]/n_batch).astype(np.int32)
  loss = np.zeros((n_epochs,4))
  lossv = np.zeros((n_epochs,1))
  count = 0
  for m in range(n_epochs):
    start = time.time()
    dis_loss1 = 0
    dis_loss2 = 0
    gan_loss  = 0
    val_loss  = 0
    for j in range(n_steps):
      [X1, X2], y_real = generate_real_samples(low,gt,n_batch)
      [X1_valid, X2_valid], y_real_valid = generate_real_samples(low_valid,gt_valid,n_batch)
      X_fakeB, y_fake = generate_fake_samples(X1,gen_model,n_batch)
      dis_loss1 = dis_model.train_on_batch([X1, X2], y_real) +dis_loss1
      dis_loss2 = dis_model.train_on_batch([X1, X_fakeB], y_fake) + dis_loss2
      gan_loss1,_,_  = gan_model.train_on_batch(X1,[y_real,X2])
      gan_loss = gan_loss1 + gan_loss
      val_loss1 = gen_model.test_on_batch(X1_valid,X2_valid)
      val_loss = val_loss1 + val_loss
    val_loss = val_loss/n_steps
    gan_loss = gan_loss/n_steps
    dis_loss1 = dis_loss1/n_steps
    dis_loss2 = dis_loss2/n_steps



    loss[m,0] = dis_loss1
    loss[m,1] = dis_loss2
    loss[m,2] = gan_loss 
    loss[m,3] = val_loss
    lossv[m,0] = val_loss

    display.clear_output(wait=True)

    current_lr = float(tf.keras.backend.get_value(gan_model.optimizer.learning_rate))
    print('learning rate:',current_lr)
    # Halve the learning rate after every 10 consecutive epochs without improvement
    if np.remainder(count+1,10) == 0:
      if current_lr>1e-7:
        update_lr  = current_lr*0.5
        tf.keras.backend.set_value(gan_model.optimizer.learning_rate,update_lr)
        tf.keras.backend.set_value(dis_model.optimizer.learning_rate,update_lr)

    if m==0:
      gen_model.save_weights(model_save_directory,overwrite=True)
    else:
      if val_loss <= np.min(lossv[np.nonzero(lossv)]):
        gen_model.save_weights(model_save_directory,overwrite=True)
        print('model is saved')
        count = 0
      else:
        count = count + 1
    if count==100:
      print('Training is stopped')
      break 
    
    np.savetxt(loss_save_directory, loss, delimiter=",")


    
    print('>%d, d1[%.3f] d2[%.3f] g[%.3f] val[%.3f]' % (m+1, dis_loss1, dis_loss2, gan_loss, val_loss))
    print ('Time for epoch {} is {} sec'.format(m + 1, time.time()-start))
    # Show one validation patch: input, current prediction, ground truth
    ix = 4
    predictions = gen_model.predict(low_valid[ix:ix+1], verbose=0)
    fig = plt.figure(figsize = (20,15))
    for i in range(predictions.shape[0]):
      plt.subplot(1, 3, 3*i+1)
      plt.imshow(low_valid[i+ix, :, :, 0] , cmap='magma')
      plt.axis('off')
      plt.subplot(1, 3, 3*i+2)
      plt.imshow(predictions[i, :, :, 0] , cmap='magma')
      plt.axis('off')
      plt.subplot(1, 3, 3*i+3)
      plt.imshow(gt_valid[i+ix, :, :, 0] , cmap='magma')
      plt.axis('off')
    plt.show()

  return loss

loss = train(dis_model,gen_model,gan_model,X_train,Y_train,X_valid,Y_valid,n_epochs = n_epochs, n_batch=n_batch)

# **Cell #7**
 
Here, you can plot the training and validation losses to check whether your network needs to be trained for more epochs. Set **epoch1** and **epoch2** to choose the range of epochs to plot.

In [None]:
#@title Plot training and validation losses { form-width: "50%" }
epoch1 = 40 #@param {type:"integer"}
epoch2 = 100 #@param {type:"integer"}
loss_value_directory = "/content/drive/My Drive/model-GAN-loss.csv" #@param {type:"string"}

# CSV columns: discriminator loss (real), discriminator loss (fake), GAN loss, validation loss
ll = np.loadtxt(loss_value_directory, delimiter=",")
epoch2 = min(epoch2, ll.shape[0])
mm = np.arange(epoch1 + 1, epoch2 + 1)  # row m of the file holds epoch m + 1
plt.plot(mm, ll[epoch1:epoch2, 2], mm, ll[epoch1:epoch2, 3])
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(['training loss','validation loss'])

# **Cell #8**

Now it is time to have fun and see the results of your trained network. First, import your test dataset in the same way as your training dataset: paste the path to your test folder, then set the size of your test images and the number of patches you would like to generate per image. (The default path is the training folder of the example dataset; replace it with your own test folder.)

In [None]:
#@title Directory to the test set { form-width: "50%" }
test_image_dr = "/content/drive/My Drive/data/Fast_STED_Confocal_vs_STED/STED" #@param {type:"string"}
test_image_size = 1024 #@param {type:"integer"}
test_patch_size = 128 #@param {type:"integer"}
test_n_patches_per_image =  64#@param {type:"integer"}
# Sub-folder names inside test_image_dr for the ground-truth and low-SNR images
test_gt_folder = "Average" #@param {type:"string"}
test_low_folder = "1frame" #@param {type:"string"}

test_image_list = sorted(os.listdir(os.path.join(test_image_dr, test_gt_folder)))
print('The image file names are:', *test_image_list, sep='\n')
Ms = len(test_image_list)
test_GT = np.empty((Ms,test_image_size,test_image_size,1),dtype=np.float32)
test_low = np.empty((Ms,test_image_size,test_image_size,1),dtype=np.float32)

for i in range(Ms):
  img_GT = load_img(os.path.join(test_image_dr, test_gt_folder, test_image_list[i]), color_mode="grayscale")
  img_low = load_img(os.path.join(test_image_dr, test_low_folder, test_image_list[i]), color_mode="grayscale")

  test_GT[i] = img_to_array(img_GT).astype(np.float32)
  test_low[i] = img_to_array(img_low).astype(np.float32)

# Regular grid of patches, as for the training set
test_n_patches_per_row = int(round(np.sqrt(test_n_patches_per_image)))
assert test_n_patches_per_row**2 == test_n_patches_per_image, "test_n_patches_per_image must be a perfect square"

test_rr = np.floor(np.linspace(0, test_image_size - test_patch_size, test_n_patches_per_row)).astype(int)
test_cc = test_rr

test_xx = np.empty((Ms*test_n_patches_per_image,test_patch_size,test_patch_size,1),dtype=np.float32)
test_yy = np.empty((Ms*test_n_patches_per_image,test_patch_size,test_patch_size,1),dtype=np.float32)

test_count = 0
for i in range(Ms):
  for j in range(test_n_patches_per_row):
    for k in range(test_n_patches_per_row):
      r, c = test_rr[j], test_cc[k]
      test_xx[test_count] = test_low[i, r:r+test_patch_size, c:c+test_patch_size, :]
      test_xx[test_count] = test_xx[test_count]/max(test_xx[test_count].max(), 1e-8)
      test_yy[test_count] = test_GT[i, r:r+test_patch_size, c:c+test_patch_size, :]
      test_yy[test_count] = test_yy[test_count]/max(test_yy[test_count].max(), 1e-8)
      test_count+=1

print('The test set shape is:',test_xx.shape)

# **Cell #9**

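In this cell, you can load the model that you just trained and test it on your imported test set. To do so, paste the path to your saved model weights and run the cell. The cell reuses the functions and hyperparameters defined in Cell #5, so run that cell first if you have restarted the runtime.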

In [None]:
#@title Test a trained model on a test set { form-width: "50%" }
model_directory = "/content/drive/My Drive/model-GAN.h5" #@param {type:"string"}

gan_lr = learning_rate
gan_opt = keras.optimizers.Adam(learning_rate=gan_lr)
# make_generator, percep_loss and the hyperparameters come from Cell #5.
# The generator is fully convolutional, so it can be rebuilt for the larger test patches
# and still load the weights trained on smaller patches.
gen_input = Input((test_patch_size,test_patch_size,1))
gen_model = make_generator(gen_input,num_filters=num_filters,num_layers=num_layers,kernel_shape=(3,3),dropout=0.3)
gen_model.compile(optimizer=gan_opt,loss=percep_loss)

# Load the weights from the path given above
gen_model.load_weights(model_directory)

# Predict the whole test set in batches
prediction = gen_model.predict(test_xx, batch_size=8, verbose=0)

# **Cell #10**
After predicting your test set, you can plot the output images and compare them with the ground truth here.
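Besides comparing the images by eye, a single number such as the peak signal-to-noise ratio (PSNR) can help. The sketch below is not part of the notebook; it assumes both images are scaled to [0, 1], as the patches here are, so the peak value is 1, and it clips the prediction to that range first. Higher values are better.

```python
import numpy as np

def psnr(pred, gt, max_val=1.0):
    """Peak signal-to-noise ratio in dB between two images scaled to [0, max_val]."""
    pred = np.clip(np.asarray(pred, dtype=np.float64), 0.0, max_val)
    gt = np.asarray(gt, dtype=np.float64)
    mse = np.mean((pred - gt) ** 2)
    if mse == 0:
        return float("inf")
    return 10.0 * np.log10(max_val ** 2 / mse)

gt = np.zeros((4, 4))
print(psnr(gt + 0.1, gt))  # about 20 dB, since the mean squared error is 0.01
```

In this notebook it could be called as `psnr(prediction[ix, :, :, 0], test_yy[ix, :, :, 0])`; for a structural measure, `tf.image.ssim`, which the training loss already uses, is an alternative.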

In [None]:
#@title Plot your prediction results
image_number = 10 #@param {type:"integer"}
ix = image_number  # or pick one at random: ix = np.random.randint(len(prediction))
assert 0 <= ix < len(prediction), "image_number must be smaller than the number of test patches"
fig = plt.figure(figsize = (20,15))
plt.subplot(1, 3, 1)
plt.imshow(test_xx[ix, :, :, 0] , cmap='magma')
plt.title('Low SNR')
plt.axis('off')
plt.subplot(1, 3, 2)
plt.imshow(prediction[ix, :, :, 0] , cmap='magma')
plt.title('Prediction')
plt.axis('off')
plt.subplot(1, 3, 3)
plt.imshow(test_yy[ix, :, :, 0] , cmap='magma')
plt.title('Ground truth')
plt.axis('off')

# **Cell #11**

Finally, it is time to save your test patches and predictions. The images are converted to 16-bit integers and saved as three image stacks (`prediction.tif`, `x_test.tif` and `y_test.tif`) in the same folder as your test dataset.
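The conversion to 16 bits maps the normalized range [0, 1] onto the integers 0 to 65535. Because the generator ends in a plain convolution without an activation, its output can fall slightly outside [0, 1]; clipping first avoids negative values wrapping around when cast to an unsigned type. A minimal sketch of that conversion (the helper name `to_uint16` is ours):

```python
import numpy as np

def to_uint16(images):
    """Map float images in [0, 1] to uint16 in [0, 65535], clipping out-of-range values."""
    scaled = np.clip(images, 0.0, 1.0) * (2**16 - 1)
    return np.round(scaled).astype(np.uint16)

print(to_uint16(np.array([-0.2, 0.0, 1.0, 1.3])))  # [    0     0 65535 65535]
```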

In [None]:
#@title Save your results { form-width: "50%" }
from tifffile import imwrite

pred_test = prediction.reshape(prediction.shape[0],test_patch_size,test_patch_size)
X_test = test_xx.reshape(test_xx.shape[0],test_patch_size,test_patch_size)
Y_test = test_yy.reshape(test_yy.shape[0],test_patch_size,test_patch_size)

# Clip the prediction to [0, 1] first: the generator output is not bounded, and negative
# values would wrap around when cast to an unsigned integer type
pred_test = np.clip(pred_test, 0, 1)*(2**16-1)
X_test = X_test*(2**16-1)
Y_test = Y_test*(2**16-1)

pred_test = np.round(pred_test).astype(np.uint16)
X_test = np.round(X_test).astype(np.uint16)
Y_test = np.round(Y_test).astype(np.uint16)

imwrite(os.path.join(test_image_dr, 'prediction.tif'), pred_test)
imwrite(os.path.join(test_image_dr, 'x_test.tif'), X_test)
imwrite(os.path.join(test_image_dr, 'y_test.tif'), Y_test)