In [None]:
# phase-attention-guided cycle-consistent generative adversarial network (pcGAN), implemented using Jupyter Notebook

import keras.backend as K
from keras.models import load_model
from tensorflow import set_random_seed
from models import BASIC_D,UNET_G_phaseatt_ds_deep,UNET_G
from keras.optimizers import RMSprop, SGD, Adam
import tensorflow as tf
import numpy as np
from utils import cycle_variables, read_image, minibatchAB, minibatch, rescale, load_data, G
from loss import D_loss
from PIL import Image
from random import randint, shuffle
import glob
import skimage.io as io
import os
import time
import random
import scipy.misc


# Reproducibility: seed Python's `random`, NumPy and TensorFlow with the same value.
seed = 11
for _seed_fn in (random.seed, np.random.seed, set_random_seed):
    _seed_fn(seed)
del _seed_fn  # keep the loop variable out of the notebook namespace

# Images are laid out as (height, width, channels); channel axis is the last one.
K.set_image_data_format('channels_last')
channel_axis, channel_first = -1, False

# Set parameters
nc_in = nc_out = 3          # image channels: nc_in for domain A, nc_out for domain B
ngf = ndf = 64              # base filter counts passed to the generators / discriminators
use_lsgan = True            # least-squares GAN loss (discriminators then built without sigmoid)
lambd1 = 100 if not use_lsgan else 10   # weight of the cycle-consistency loss in loss_G
lambd2 = 0.2                # weight of the MS-SSIM loss in loss_G

# Image size and batch size
loadSize = imageSize = 512
batchSize = 1

# Learning rates and the `decay` values handed to the Adam optimizers (D and G)
lrD = lrG = 2e-4
lr_decay_D = lr_decay_G = 1e-5


In [None]:
# pcGAN initialization
# Builds the two generators, the two discriminators, the combined losses and the
# two training functions (netD_train / netG_train) used by the training loop.
# netGA: plain U-Net generator (note the nc_out, nc_in argument order).
netGA = UNET_G(imageSize, nc_out, nc_in, ngf)
# netGB: U-Net variant that, judging by its name, adds phase attention and deep supervision.
netGB = UNET_G_phaseatt_ds_deep(imageSize, nc_in, nc_out, ngf)
# One discriminator per domain; the sigmoid output is only enabled when LSGAN is off.
netDA = BASIC_D(nc_in, ndf, use_sigmoid = not use_lsgan)
netDB = BASIC_D(nc_out, ndf, use_sigmoid = not use_lsgan)


# cycle_variables (utils, not visible here) returns, per the names it is unpacked into:
# input placeholder, translated image, reconstruction, identity output and a generate
# function. With (netGB, netGA) the chain is A -> B -> A; with (netGA, netGB) it is B -> A -> B.
# NOTE: "indentity" is a misspelling kept as-is because it is a variable name.
real_A, fake_B, rec_A, indentity_A, cycleA_generate = cycle_variables(netGB, netGA)
real_B, fake_A, rec_B, indentity_B, cycleB_generate = cycle_variables(netGA, netGB)

# Loss function
# D_loss (loss.py, not visible) yields per domain: discriminator loss, generator
# adversarial loss, cycle loss and MS-SSIM loss (per the unpacked names).
loss_DA, loss_GA, loss_cycA, loss_msSSIMA = D_loss(netDA, real_A, fake_A, fake_B, rec_A)
loss_DB, loss_GB, loss_cycB, loss_msSSIMB = D_loss(netDB, real_B, fake_B, fake_A, rec_B)
loss_cyc = loss_cycA + loss_cycB
loss_msSSIM = loss_msSSIMA + loss_msSSIMB

# Generator objective: both adversarial terms + lambd1 * cycle + lambd2 * MS-SSIM.
loss_G = loss_GA + loss_GB + lambd1 * loss_cyc + lambd2 * loss_msSSIM
loss_D = loss_DA+loss_DB

# Each optimizer only updates its own side's weights.
weightsD = netDA.trainable_weights + netDB.trainable_weights
weightsG = netGA.trainable_weights + netGB.trainable_weights


# Adam with beta_1=0.5. Each K.function runs its update ops on every call and returns
# the listed tensors. `training_updates` is reused: each K.function has already
# captured its own update list when built.
# NOTE(review): get_updates(params, [], loss) is the old (params, constraints, loss)
# positional form; newer Keras expects get_updates(loss, params) — confirm the pinned version.
training_updates = Adam(lr=lrD, beta_1=0.5, decay=lr_decay_D).get_updates(weightsD,[], loss_D)
# Reported D losses are halved; the optimizer itself minimizes the unhalved loss_D.
netD_train = K.function([real_A, real_B],[loss_DA/2, loss_DB/2], training_updates)
training_updates = Adam(lr=lrG, beta_1=0.5, decay=lr_decay_G).get_updates(weightsG,[], loss_G)
netG_train = K.function([real_A, real_B], [loss_GA, loss_GB, loss_cyc, loss_msSSIM], training_updates)


In [None]:
# Training data location
# glob returns files in filesystem-dependent order; sorting keeps the data order
# (and therefore the seeded run) reproducible across machines.
data = "data"
train_A = sorted(glob.glob(os.path.join(data, 'trainA', '*.tif')))
train_B = sorted(glob.glob(os.path.join(data, 'trainB', '*.tif')))

# Fail fast with an actionable message (a bare `assert` says nothing and is skipped under -O).
if not train_A or not train_B:
    raise FileNotFoundError(
        'No training images found: {} in {}, {} in {} (expected *.tif)'.format(
            len(train_A), os.path.join(data, 'trainA'),
            len(train_B), os.path.join(data, 'trainB')))


In [None]:
# Training process
# Each iteration pulls one (A, B) minibatch, runs one discriminator update and one
# generator update, and accumulates the returned losses. Every `display_iters`
# iterations the averaged losses are printed and the accumulators reset.
# Checkpoints are written once per `ckpt_every` epochs and once after training.
t0 = time.time()
niter = 30                      # Number of epochs to train for
gen_iterations = 0
epoch = 0
errSSIM_sum = errCyc_sum = errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0

display_iters = 50              # Print averaged losses every this many iterations
ckpt_every = 5                  # Save checkpoints every this many epochs
ckpt_dir = 'ckpts'
train_batch = minibatchAB(train_A, train_B, batchSize)


def save_checkpoints(tag):
    """Save both generators and both discriminators as `<ckpt_dir>/<net>_pcGAN_epoch<tag>.h5`."""
    for name, net in (('netGB', netGB), ('netGA', netGA), ('netDA', netDA), ('netDB', netDB)):
        net.save(f'{ckpt_dir}/{name}_pcGAN_epoch{tag}.h5')


# exist_ok: re-running this cell must not crash just because the folder already exists.
os.makedirs(ckpt_dir, exist_ok=True)
while epoch < niter:
    epoch_ori = epoch
    epoch, A, B = next(train_batch)
    # NOTE(review): A appears to stack three channels (emp, intensity, phase) — confirm.

    # Checkpoint once, when the epoch counter reported by minibatchAB advances to a
    # multiple of `ckpt_every`. Saving inside the display branch instead would
    # re-save (overwrite) the same files every `display_iters` iterations.
    # The epoch that ends training is saved after the loop.
    if epoch != epoch_ori and epoch % ckpt_every == 0 and epoch < niter:
        save_checkpoints(epoch)

    errDA, errDB = netD_train([A, B])
    errDA_sum += errDA
    errDB_sum += errDB

    errGA, errGB, errCyc, errSSIM = netG_train([A, B])
    errGA_sum += errGA
    errGB_sum += errGB
    errCyc_sum += errCyc
    errSSIM_sum += errSSIM
    gen_iterations += 1

    if gen_iterations % display_iters == 0:
        print('[%d/%d][%d] Loss_D: %f %f Loss_G: %f %f loss_cyc %f loss_msSSIM %f'
              % (epoch, niter, gen_iterations, errDA_sum/display_iters, errDB_sum/display_iters,
                 errGA_sum/display_iters, errGB_sum/display_iters,
                 errCyc_sum/display_iters, errSSIM_sum/display_iters), time.time()-t0)
        # No extra preview batch is pulled here: its result would be unused, and
        # drawing it from train_batch would presumably skip those images this pass.
        errSSIM_sum = errCyc_sum = errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0

# The loop exits right after training on the first minibatch reported as epoch `niter`,
# which the periodic check above skips, so always persist the final weights here.
save_checkpoints(epoch)

In [None]:
# Testing section
# Translates each test image with cycleA_generate and writes the result to
# `test_outputFolder` as 1.tif, 2.tif, ... in the order the batches are produced.
test_A = sorted(glob.glob('data/test/*.tif'))     # Folder name for testing (sorted: stable numbering)
Num_TestImg = len(test_A)                         # Image number in the folder
test_outputFolder = 'testOutput'                  # Output folder name

if not test_A:
    raise FileNotFoundError('No test images found matching data/test/*.tif')

# minibatchAB also needs a B stream; only Atest is used. Pass a copy so that any
# in-place reordering of one list cannot affect the other.
test_B = list(test_A)
test_batch = minibatchAB(test_A, test_B, 1)

# exist_ok: re-running this cell must not crash just because the folder already exists.
os.makedirs(test_outputFolder, exist_ok=True)
for i in range(Num_TestImg):
    epochtest, Atest, Btest = next(test_batch)
    rA = G(cycleA_generate, Atest)
    # NOTE(review): rA[0] is assumed to be the translated A -> B image; reshaped to (H, W, 3).
    im1 = np.reshape(rA[0], (imageSize, imageSize, 3))
    tf.keras.preprocessing.image.save_img(os.path.join(test_outputFolder, str(i + 1) + '.tif'), im1)