##### Copyright 2019 The TensorFlow Authors.

In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# CycleGAN

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/tutorials/generative/cyclegan"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/generative/cyclegan.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/cyclegan.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/generative/cyclegan.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

## Set up the input pipeline

Install the [tensorflow_examples](https://github.com/tensorflow/examples) package that enables importing of the generator and the discriminator.

In [0]:
# Install the tensorflow_examples package, which provides pix2pix.unet_generator and pix2pix.discriminator.
# NOTE(review): this installs the repository HEAD, so results may drift over time; consider pinning a commit.
!pip install git+https://github.com/tensorflow/examples.git

In [0]:
import tensorflow as tf
tf.__version__

In [0]:
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix
import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
import IPython.display as display
from PIL import Image
import numpy as np
import pathlib

# Keep TFDS progress bars out of the notebook output.
tfds.disable_progress_bar()
# Let tf.data pick the parallelism level for the map() calls below.
AUTOTUNE = tf.data.experimental.AUTOTUNE

## Input Pipeline

The original tutorial trains a model to translate images of horses into images of zebras. You can find that dataset and similar ones [here](https://www.tensorflow.org/datasets/datasets#cycle_gan). This notebook adapts it to translate between images and label maps loaded from Google Drive.

As mentioned in the [paper](https://arxiv.org/abs/1703.10593), apply random jittering and mirroring to the training dataset. These are image augmentation techniques that help avoid overfitting.

This is similar to what was done in [pix2pix](https://www.tensorflow.org/tutorials/generative/pix2pix#load_the_dataset).

* In random jittering, the image is resized to `286 x 286` and then randomly cropped to `256 x 256`.
* In random mirroring, the image is randomly flipped horizontally, i.e., left to right.

In [0]:
# Mount Google Drive so the dataset and checkpoint paths under /content/drive are reachable.
from google.colab import drive
drive.mount('/content/drive')

In [0]:
# data path
path_image = '/content/drive/My Drive/Course/CS230/CycleGAN/img_v1'
path_label = '/content/drive/My Drive/Course/CS230/CycleGAN/lbl_v5'

# data names -- os.listdir returns entries in arbitrary, filesystem-dependent
# order, so sort them to make the manual train/test split reproducible.
im_list = sorted(os.listdir(path_image))
la_list = sorted(os.listdir(path_label))

# data full path
image_list = [os.path.join(path_image, ele) for ele in im_list]
label_list = [os.path.join(path_label, ele) for ele in la_list]

# manual split: the first N files of each domain are used for training, the rest for test
N_IMAGE_TRAIN = 24
N_LABEL_TRAIN = 30
image_list_train = image_list[:N_IMAGE_TRAIN]
label_list_train = label_list[:N_LABEL_TRAIN]
image_list_test = image_list[N_IMAGE_TRAIN:]
label_list_test = label_list[N_LABEL_TRAIN:]

# convert to tensor
image_list_train_tf = tf.constant(image_list_train)
label_list_train_tf = tf.constant(label_list_train)
image_list_test_tf = tf.constant(image_list_test)
label_list_test_tf = tf.constant(label_list_test)

# print length
print(len(image_list_train))
print(len(label_list_train))
print(len(image_list_test))
print(len(label_list_test))

print(label_list_test)

In [0]:
# Input-pipeline constants.
BUFFER_SIZE = 26              # shuffle buffer size, in elements
BATCH_SIZE = 4
IMG_WIDTH = IMG_HEIGHT = 256  # square crop size fed to the generators

In [0]:
####
def random_crop(image):
  """Randomly crops an (H, W, 3) image tensor to (IMG_HEIGHT, IMG_WIDTH, 3).

  tf.image.random_crop expects `size` in the tensor's own dimension order.
  Decoded images are laid out height-first, so height comes before width.
  The original [IMG_WIDTH, IMG_HEIGHT] order only worked because both are 256.
  """
  return tf.image.random_crop(image, size=[IMG_HEIGHT, IMG_WIDTH, 3])

####
def random_jitter(image, jitter_size=286):
  """Random-jitter + mirroring augmentation.

  The image is upscaled to (jitter_size, jitter_size), randomly cropped back
  to (IMG_HEIGHT, IMG_WIDTH), and randomly flipped left/right. Resizing
  straight to the crop size (as 256 did) makes the crop a no-op, so
  jitter_size must be >= IMG_HEIGHT/IMG_WIDTH and should exceed them.

  Args:
    image: (H, W, 3) image tensor.
    jitter_size: side length to resize to before cropping. Defaults to 286,
      the value described in the markdown above.

  Returns:
    (IMG_HEIGHT, IMG_WIDTH, 3) augmented tensor.
  """
  # Nearest-neighbour resizing does not interpolate, so no new pixel values
  # are introduced (relevant for the label maps).
  image = tf.image.resize(image, [jitter_size, jitter_size],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  image = random_crop(image)
  image = tf.image.random_flip_left_right(image)
  return image

####
def decode_images(img):
  """Decodes PNG bytes into a float32 (H, W, 3) tensor with values in [0, 1].

  channels=3 forces 3-channel output, so grayscale or RGBA PNGs do not break
  the 3-channel random_crop and the generator input downstream.
  """
  img = tf.image.decode_png(img, channels=3)
  # uint8 -> float32 conversion also rescales values to [0, 1].
  img = tf.image.convert_image_dtype(img, tf.float32)
  return img

def decode_labels(img):
  """Decodes label PNG bytes into a float32 (H, W, 3) tensor.

  Unlike decode_images, the [0, 1] rescaling done by the dtype conversion is
  undone by multiplying by 255, so values stay on the original pixel scale.
  """
  as_float = tf.image.convert_image_dtype(tf.image.decode_png(img, channels=3),
                                          tf.float32)
  return as_float * 255.

####
def _read_then_jitter(file_path, decode_fn):
  """Reads one file, decodes it with `decode_fn`, then applies random_jitter."""
  raw_bytes = tf.io.read_file(file_path)
  return random_jitter(decode_fn(raw_bytes))


def process_path_images_train(file_path):
  """Training map fn for an input-image path: read, decode, augment."""
  return _read_then_jitter(file_path, decode_images)

def process_path_labels_train(file_path):
  """Training map fn for a label path: read, decode, augment."""
  return _read_then_jitter(file_path, decode_labels)

####
def _read_and_resize(file_path, decode_fn):
  """Reads one file, decodes it, and resizes it to (IMG_HEIGHT, IMG_WIDTH).

  Test data must be deterministic, so no random crop or flip is applied.
  Nearest-neighbour resizing matches the method used during training.
  """
  raw_bytes = tf.io.read_file(file_path)
  img = decode_fn(raw_bytes)
  return tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH],
                         method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)


def process_path_images_test(file_path):
  """Test map fn for an input-image path: read, decode, resize (no augmentation)."""
  return _read_and_resize(file_path, decode_images)

def process_path_labels_test(file_path):
  """Test map fn for a label path: read, decode, resize (no augmentation)."""
  return _read_and_resize(file_path, decode_labels)

In [0]:
# global random seed
tf.random.set_seed(0)

# datasets of file-path strings
image_train_ds = tf.data.Dataset.from_tensor_slices(image_list_train_tf)
label_train_ds = tf.data.Dataset.from_tensor_slices(label_list_train_tf)
image_test_ds = tf.data.Dataset.from_tensor_slices(image_list_test_tf)
label_test_ds = tf.data.Dataset.from_tensor_slices(label_list_test_tf)


def _read_decode(decode_fn):
  """Returns a map fn: file path -> tensor decoded by `decode_fn` (no augmentation)."""
  return lambda file_path: decode_fn(tf.io.read_file(file_path))


# Training: cache the *decoded* files, then shuffle and augment on every pass.
# Caching after random_jitter would freeze one random crop/flip per file and
# replay it every epoch, which defeats the augmentation. The shuffle buffer
# covers the whole split, so every element can land anywhere in the epoch.
image_train_ds = (image_train_ds
                  .map(_read_decode(decode_images), num_parallel_calls=AUTOTUNE)
                  .cache()
                  .shuffle(max(BUFFER_SIZE, len(image_list_train)))
                  .map(random_jitter, num_parallel_calls=AUTOTUNE)
                  .batch(BATCH_SIZE))
label_train_ds = (label_train_ds
                  .map(_read_decode(decode_labels), num_parallel_calls=AUTOTUNE)
                  .cache()
                  .shuffle(max(BUFFER_SIZE, len(label_list_train)))
                  .map(random_jitter, num_parallel_calls=AUTOTUNE)
                  .batch(BATCH_SIZE))

# Test: evaluation order does not matter, so no shuffle buffer is needed.
image_test_ds = image_test_ds.map(process_path_images_test, num_parallel_calls=AUTOTUNE).cache().batch(1)
label_test_ds = label_test_ds.map(process_path_labels_test, num_parallel_calls=AUTOTUNE).cache().batch(1)

In [0]:
# Take one shuffled, augmented batch from each training dataset.
sample_image, sample_label = (next(iter(ds)) for ds in (image_train_ds, label_train_ds))
for batch in (sample_image, sample_label):
  print(batch.shape)

In [0]:
# Take the first batch (batch size 1) from each test dataset.
test_sample_image, test_sample_label = next(iter(image_test_ds)), next(iter(label_test_ds))
for batch in (test_sample_image, test_sample_label):
  print(batch.shape)

In [0]:
# A training image next to the same image passed through random_jitter once more.
fig, (ax_plain, ax_jittered) = plt.subplots(1, 2, figsize=(12, 5))
ax_plain.set_title('Image')
ax_plain.imshow(sample_image[0])
ax_jittered.set_title('Image with random jitter')
ax_jittered.imshow(random_jitter(sample_image[0]));

In [0]:
# A training label next to the same label passed through random_jitter once more.
fig, (ax_plain, ax_jittered) = plt.subplots(1, 2, figsize=(12, 5))
ax_plain.set_title('Label')
ax_plain.imshow(sample_label[0])
ax_jittered.set_title('Label with random jitter')
ax_jittered.imshow(random_jitter(sample_label[0]));

## Import and reuse the Pix2Pix models

Import the generator and the discriminator used in [Pix2Pix](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py) via the installed [tensorflow_examples](https://github.com/tensorflow/examples) package.

The model architecture used in this tutorial is very similar to what was used in [pix2pix](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py). Some of the differences are:

* Cyclegan uses [instance normalization](https://arxiv.org/abs/1607.08022) instead of [batch normalization](https://arxiv.org/abs/1502.03167).
* The [CycleGAN paper](https://arxiv.org/abs/1703.10593) uses a modified `resnet` based generator. This tutorial is using a modified `unet` generator for simplicity.

There are 2 generators (`G` and `F`) and 2 discriminators (`D_X` and `D_Y`) being trained here.

* Generator `G` learns to transform image `X` to image `Y`. $(G: X \rightarrow Y)$
* Generator `F` learns to transform image `Y` to image `X`. $(F: Y \rightarrow X)$
* Discriminator `D_X` learns to differentiate between image `X` and generated image `X` (`F(Y)`).
* Discriminator `D_Y` learns to differentiate between image `Y` and generated image `Y` (`G(X)`).

![Cyclegan model](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/cyclegan_model.png?raw=1)

In [0]:
# Two U-Net generators with 3 output channels (G: image -> label, F: label -> image)
# and two discriminators (D_X judges images, D_Y judges labels), all instance-normalized.
generator_g, generator_f = [pix2pix.unet_generator(3, norm_type='instancenorm')
                            for _ in range(2)]

discriminator_x, discriminator_y = [pix2pix.discriminator(norm_type='instancenorm', target=False)
                                    for _ in range(2)]

In [0]:
# Translate one sample batch in each direction with the (untrained) generators
# and show the first element of each batch in a 2x2 grid.
to_label = generator_g(sample_image)
to_image = generator_f(sample_label)
plt.figure(figsize=(8, 8))

imgs = [sample_image, to_label, sample_label, to_image]
title = ['Image', 'To label', 'Label', 'To image']

# Both branches of the former `if i % 2 == 0` did the same thing, and the
# `contrast` variable was never used, so one code path remains.
for i, (batch, caption) in enumerate(zip(imgs, title)):
  plt.subplot(2, 2, i + 1)
  plt.title(caption)
  plt.imshow(batch[0])
plt.show()

In [0]:
# Last output channel of each (untrained) discriminator for the first sample.
panels = [('Is a real label?', discriminator_y, sample_label),
          ('Is a real image?', discriminator_x, sample_image)]

plt.figure(figsize=(8, 8))
for idx, (caption, disc, batch) in enumerate(panels):
  plt.subplot(1, 2, idx + 1)
  plt.title(caption)
  plt.imshow(disc(batch)[0, ..., -1], cmap='RdBu_r')

plt.show()

## Loss functions

In CycleGAN, there is no paired data to train on, hence there is no guarantee that the input `x` and the target `y` pair are meaningful during training. Thus in order to enforce that the network learns the correct mapping, the authors propose the cycle consistency loss.

The discriminator loss and the generator loss are similar to the ones used in [pix2pix](https://www.tensorflow.org/tutorials/generative/pix2pix#define_the_loss_functions_and_the_optimizer).

In [0]:
# Weight of the cycle-consistency loss; identity_loss uses LAMBDA * 0.5.
LAMBDA = 10

In [0]:
# Adversarial objective: mean squared error between discriminator outputs and
# 1/0 targets (least-squares style), used by discriminator_loss / generator_loss.
loss_obj = tf.keras.losses.MeanSquaredError()
# Alternative: binary cross-entropy on the discriminator logits.
#loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [0]:
def discriminator_loss(real, generated):
  """Discriminator loss: push outputs on real inputs toward 1 and on generated inputs toward 0.

  Returns half the sum of the two terms, i.e. their mean.
  """
  loss_on_real = loss_obj(tf.ones_like(real), real)
  loss_on_fake = loss_obj(tf.zeros_like(generated), generated)
  return 0.5 * (loss_on_real + loss_on_fake)

In [0]:
def generator_loss(generated):
  """Adversarial generator loss: rewards discriminator outputs close to 1 ("real")."""
  target_real = tf.ones_like(generated)
  return loss_obj(target_real, generated)

Cycle consistency means the result should be close to the original input. For example, if one translates a sentence from English to French, and then translates it back from French to English, then the resulting sentence should be the same as the original sentence.

In cycle consistency loss, 

* Image $X$ is passed via generator $G$ that yields generated image $\hat{Y}$.
* Generated image $\hat{Y}$ is passed via generator $F$ that yields cycled image $\hat{X}$.
* Mean absolute error is calculated between $X$ and $\hat{X}$.

$$forward\ cycle\ consistency\ loss: X \rightarrow G(X) \rightarrow F(G(X)) \sim \hat{X}$$

$$backward\ cycle\ consistency\ loss: Y \rightarrow F(Y) \rightarrow G(F(Y)) \sim \hat{Y}$$


![Cycle loss](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/cycle_loss.png?raw=1)

In [0]:
def calc_cycle_loss(real_image, cycled_image):
  """Cycle-consistency loss: LAMBDA times the mean absolute error."""
  mean_abs_err = tf.reduce_mean(tf.abs(real_image - cycled_image))
  return LAMBDA * mean_abs_err

As shown above, generator $G$ is responsible for translating image $X$ to image $Y$. Identity loss says that, if you fed image $Y$ to generator $G$, it should yield the real image $Y$ or something close to image $Y$.

$$Identity\ loss = |G(Y) - Y| + |F(X) - X|$$

In [0]:
def identity_loss(real_image, same_image):
  """Identity loss: mean absolute error weighted by LAMBDA * 0.5."""
  abs_diff = tf.abs(real_image - same_image)
  return LAMBDA * 0.5 * tf.reduce_mean(abs_diff)

Initialize the optimizers for all the generators and the discriminators.

In [0]:
def _adam(learning_rate):
  """Adam with beta_1=0.5, the setting shared by all four optimizers."""
  return tf.keras.optimizers.Adam(learning_rate, beta_1=0.5)

# Generators use twice the discriminators' learning rate.
generator_g_optimizer = _adam(2e-4)
generator_f_optimizer = _adam(2e-4)

discriminator_x_optimizer = _adam(1e-4)
discriminator_y_optimizer = _adam(1e-4)

## Checkpoints

In [0]:
# Absolute path: a relative "drive/..." path only resolves while the working
# directory is /content, and a later cell `cd`s into the results folder before
# checkpoint_path is read again.
checkpoint_path = "/content/drive/My Drive/Course/CS230/CycleGAN/Ching-Ting/checkpoints"

# Track all four models and their optimizers so training can resume exactly.
ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')

## Training

Note: This model is trained for fewer epochs (`EPOCHS = 100` below) than the paper (200) to keep training time reasonable. Predictions may be less accurate.

In [0]:
# Number of passes over the zipped (image, label) training datasets.
EPOCHS = 100

In [0]:
# First 10 individual (unbatched) training images, as numpy arrays.
show_lst = [arr for arr in image_train_ds.unbatch().take(10).as_numpy_iterator()]
print(len(show_lst))
print(show_lst[0].shape)

In [0]:

# Disabled helper, kept for reference: for each array in show_lst it runs
# `model` on it and saves an input/prediction PNG pair to
# {save_path}/im{idx}_epoch{epoch}.png. Uncomment to use.
# save_path = 'drive/My Drive/Course/CS230/CycleGAN/Ching-Ting/results'
# def save_image_list(model, show_lst, epoch=0, save_path=save_path):
#   for idx in range(len(show_lst)):
#     test_input = np.expand_dims(show_lst[idx], axis=0)
#     prediction = model(test_input)
    
#     plt.ioff()
#     fig = plt.figure(figsize=(12, 12))

#     display_list = [test_input[0,:,:,:], prediction[0,:,:,:]]
#     title = ['Input Image', 'Predicted Image']

#     for i in range(2):
#       plt.subplot(1, 2, i+1)
#       plt.title(title[i])
#       # getting the pixel values between [0, 1] to plot it.
#       plt.imshow(display_list[i])
#       plt.axis('off')
    
#     plt.savefig(f"{save_path}/im{idx}_epoch{epoch}.png")
#     plt.close(fig)


In [0]:
def generate_images(model, test_input):
  """Plots the first element of `test_input` next to `model`'s prediction for it.

  Args:
    model: callable mapping a batch tensor to a batch tensor (a generator).
    test_input: batch of images; only index 0 is displayed.
  """
  prediction = model(test_input)

  plt.figure(figsize=(12, 12))
  panels = (('Input Image', test_input[0]), ('Predicted Image', prediction[0]))
  for slot, (caption, picture) in enumerate(panels, start=1):
    plt.subplot(1, 2, slot)
    plt.title(caption)
    # Values are shown as-is (no rescaling); imshow clips float RGB data to [0, 1].
    plt.imshow(picture)
    plt.axis('off')
  plt.show()

Even though the training loop looks complicated, it consists of four basic steps:

* Get the predictions.
* Calculate the loss.
* Calculate the gradients using backpropagation.
* Apply the gradients to the optimizer.

In [0]:
@tf.function
def train_step(real_x, real_y):
  """Runs one CycleGAN optimization step on a batch from each domain.

  Computes the adversarial, cycle-consistency and identity losses under one
  persistent tape, then applies one gradient update to each of the two
  generators and two discriminators.

  Args:
    real_x: batch from domain X (the training loop passes image batches).
    real_y: batch from domain Y (the training loop passes label batches).
  """
  # persistent is set to True because the tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent=True) as tape:
    # Generator G translates X -> Y
    # Generator F translates Y -> X.
    
    # Forward cycle: X -> G(X) -> F(G(X)).
    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)

    # Backward cycle: Y -> F(Y) -> G(F(Y)).
    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)

    # same_x and same_y are used for identity loss.
    same_x = generator_f(real_x, training=True)
    same_y = generator_g(real_y, training=True)

    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)

    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # calculate the loss
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)
    
    # The cycle loss is shared by both generators' objectives.
    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)
    
    # Total generator loss = adversarial loss + cycle loss + identity loss
    total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
    total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)
  
  # Calculate the gradients for generator and discriminator
  generator_g_gradients = tape.gradient(total_gen_g_loss, 
                                        generator_g.trainable_variables)
  generator_f_gradients = tape.gradient(total_gen_f_loss, 
                                        generator_f.trainable_variables)
  
  discriminator_x_gradients = tape.gradient(disc_x_loss, 
                                            discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(disc_y_loss, 
                                            discriminator_y.trainable_variables)
  
  # Apply the gradients to the optimizer
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients, 
                                            generator_g.trainable_variables))

  generator_f_optimizer.apply_gradients(zip(generator_f_gradients, 
                                            generator_f.trainable_variables))
  
  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                discriminator_x.trainable_variables))
  
  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                discriminator_y.trainable_variables))

In [0]:
## record loss for each epoch: one zero-initialised (EPOCHS, 1) float array per tracked loss
(Gen_G_Loss, Gen_F_Loss, Total_Cycle_Loss, Total_Gen_G_Loss,
 Total_Gen_F_Loss, Disc_X_Loss, Disc_Y_Loss) = (np.zeros((EPOCHS, 1)) for _ in range(7))

In [0]:
def save_loss(generator_g, generator_f, discriminator_x, discriminator_y, sample_image, sample_label, epoch):
  """Evaluates every CycleGAN loss on one fixed batch pair and stores it at row `epoch`.

  This repeats the loss computation of train_step without a gradient tape and
  without `training=True`. It writes into the module-level arrays Gen_G_Loss,
  Gen_F_Loss, Total_Cycle_Loss, Total_Gen_G_Loss, Total_Gen_F_Loss,
  Disc_X_Loss and Disc_Y_Loss.
  """
  # compute the loss for every epoch
  fake_label = generator_g(sample_image)
  cycled_image = generator_f(fake_label)
  fake_image = generator_f(sample_label)
  cycled_label = generator_g(fake_image)
  same_image = generator_f(sample_image)
  same_label = generator_g(sample_label)
  disc_real_image = discriminator_x(sample_image)
  disc_real_label = discriminator_y(sample_label)
  disc_fake_image = discriminator_x(fake_image)
  disc_fake_label = discriminator_y(fake_label)
  # calculate the loss
  Gen_G_Loss[epoch] = generator_loss(disc_fake_label)
  Gen_F_Loss[epoch] = generator_loss(disc_fake_image)
  Total_Cycle_Loss[epoch] = calc_cycle_loss(sample_image, cycled_image) + calc_cycle_loss(sample_label, cycled_label) 
  # Total generator loss = adversarial loss + cycle loss + identity loss
  # (reads back the values just stored in the arrays above).
  Total_Gen_G_Loss[epoch] = Gen_G_Loss[epoch] + Total_Cycle_Loss[epoch] + identity_loss(sample_label, same_label)
  Total_Gen_F_Loss[epoch] = Gen_F_Loss[epoch] + Total_Cycle_Loss[epoch] + identity_loss(sample_image, same_image)
  Disc_X_Loss[epoch] = discriminator_loss(disc_real_image, disc_fake_image)
  Disc_Y_Loss[epoch] = discriminator_loss(disc_real_label, disc_fake_label)

In [0]:
for epoch in range(EPOCHS):
  start = time.time()

  # One pass over the zipped datasets; zip stops at the shorter of the two.
  for n, (image_x, image_y) in enumerate(tf.data.Dataset.zip((image_train_ds, label_train_ds))):
    train_step(image_x, image_y)
    if n % 10 == 0:
      print ('.', end='')

  clear_output(wait=True)

  # Use the same sample batches (sample_image / sample_label) every epoch so
  # progress is visually comparable across epochs.
  generate_images(generator_g, sample_image)
  generate_images(generator_f, sample_label)

  # Record this epoch's losses on the fixed sample batches. The loss-curve
  # cell plots these arrays, which otherwise stay all zeros.
  save_loss(generator_g, generator_f, discriminator_x, discriminator_y,
            sample_image, sample_label, epoch)

  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))

  print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                      time.time()-start))

# Evaluation: export predictions and compute IoU

In [0]:
# data path
path_image = '/content/drive/My Drive/Course/CS230/CycleGAN/img_v1'
path_label = '/content/drive/My Drive/Course/CS230/CycleGAN/lbl_v1'

# data names -- sorted, because os.listdir order is arbitrary and the IoU cells
# below pair image batch k with label batch k by position.
im_list = sorted(os.listdir(path_image))
la_list = sorted(os.listdir(path_label))

# data full path
image_list = [os.path.join(path_image, ele) for ele in im_list]
label_list = [os.path.join(path_label, ele) for ele in la_list]

# manual split
N_IMAGE_TRAIN = 24
N_LABEL_TRAIN = 60
image_list_train = image_list[:N_IMAGE_TRAIN]
label_list_train = label_list[:N_LABEL_TRAIN]
image_list_test = image_list[N_IMAGE_TRAIN:]
label_list_test = label_list[N_LABEL_TRAIN:]

# convert to tensor
image_list_train_tf = tf.constant(image_list_train)
label_list_train_tf = tf.constant(label_list_train)
image_list_test_tf = tf.constant(image_list_test)
label_list_test_tf = tf.constant(label_list_test)

# print length
print(len(image_list_train))
print(len(label_list_train))
print(len(image_list_test))
print(len(label_list_test))

print(label_list_test)

In [0]:
# Display the full list of training image paths (to inspect the split).
image_list_train

In [0]:
# Display the full list of training label paths (to inspect the split).
label_list_train

In [0]:
def _load_for_eval(decode_fn):
  """Returns a map fn: file path -> decoded tensor resized to (IMG_HEIGHT, IMG_WIDTH).

  No random crop or flip is applied. random_jitter draws independently for
  images and labels, so the saved predictions and labels would no longer be
  pixel-aligned for the IoU computation below.
  """
  def _fn(file_path):
    img = decode_fn(tf.io.read_file(file_path))
    return tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH],
                           method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  return _fn

image_train_ds = tf.data.Dataset.from_tensor_slices(image_list_train_tf)
label_train_ds = tf.data.Dataset.from_tensor_slices(label_list_train_tf)

# No shuffle: batch k of each dataset follows the order of its file list.
image_train_ds = image_train_ds.map(_load_for_eval(decode_images), num_parallel_calls=AUTOTUNE).cache().batch(BATCH_SIZE)
label_train_ds = label_train_ds.map(_load_for_eval(decode_labels), num_parallel_calls=AUTOTUNE).cache().batch(BATCH_SIZE)

In [0]:
# Save each image batch and generator G's prediction for it to the current
# directory as image<k>.npy / label<k>.npy, with k counting from 1.
i = 0
for i, inp in enumerate(image_train_ds, start=1):
  prediction = generator_g(inp)
  np.save(f"image{i}", inp)
  np.save(f"label{i}", prediction)

In [0]:
# Save each label batch and generator F's prediction for it to the current
# directory as shape<k>.npy / convert<k>.npy, with k counting from 1.
i = 0
for i, inp in enumerate(label_train_ds, start=1):
  prediction = generator_f(inp)
  np.save(f"shape{i}", inp)
  np.save(f"convert{i}", prediction)

In [0]:
def converge1(img, threshold=50):
  """Binarizes a numpy array in place: values < threshold become 0, all others 255.

  This is vectorized with a boolean mask instead of a per-pixel Python loop.
  Any array rank is accepted; the previous version required a 2-D array.

  Args:
    img: numeric numpy array; it is modified in place.
    threshold: cut-off value. Defaults to 50, the previous hard-coded value.

  Returns:
    The same array object, so calls can be chained.
  """
  # Compute the mask before writing, so both assignments see the original values.
  below = img < threshold
  img[below] = 0
  img[~below] = 255
  return img

In [0]:
import cv2

In [0]:
def normalization(img):
  """Converts an RGB image to single-channel grayscale and multiplies it by 255.

  The arrays passed in come from TensorFlow (tf.image.decode_png outputs and
  generator predictions), which are RGB-ordered. COLOR_RGB2GRAY is therefore
  used; COLOR_BGR2GRAY would swap the red and blue luminance weights.

  Args:
    img: (H, W, 3) array in RGB channel order. Presumably float32 here;
      cv2.cvtColor also accepts uint8.

  Returns:
    (H, W) grayscale array multiplied by 255.
  """
  gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
  return gray * 255

In [0]:
# IoU between generator G's binarized predictions (label<k>.npy) and the
# labels (shape<k>.npy) saved above, printed for each element.
# Number of image batches written to label<k>.npy: ceil(len / BATCH_SIZE).
num_batches = -(-len(image_list_train) // BATCH_SIZE)
for i in range(1, num_batches + 1):
  lbl = np.load('label'+str(i)+'.npy')
  shape = np.load('shape'+str(i)+'.npy')
  # The last batch may hold fewer than BATCH_SIZE elements.
  for j in range(min(len(lbl), len(shape))):
    label = converge1(normalization(lbl[j]))
    truth = normalization(shape[j])
    intersection = np.logical_and(label, truth)
    union = np.logical_or(label, truth)
    union_size = np.sum(union)
    # If both masks are empty, IoU is undefined: report nan without a 0/0 warning.
    iou_score = np.sum(intersection) / union_size if union_size else float('nan')
    print(iou_score)

In [0]:

# Spot-check a single element: batch 6, element index 2.
batch_idx, elem_idx = 6, 2

lbl = np.load(f'label{batch_idx}.npy')
label = converge1(normalization(lbl[elem_idx]))

shape = np.load(f'shape{batch_idx}.npy')
truth = normalization(shape[elem_idx])

intersection = np.logical_and(label, truth)
union = np.logical_or(label, truth)
iou_score = intersection.sum() / union.sum()
print(iou_score)

In [0]:
# Show `label`, the binarized prediction from the most recently run IoU cell.
plt.imshow(label)

In [0]:
# Show `truth`, the grayscale label from the most recently run IoU cell.
plt.imshow(truth)

In [0]:

# Per-epoch loss curves from the arrays filled by save_loss (all zeros for any
# epoch where save_loss did not run). One figure per group of related losses.
epochs_axis = np.arange(0, EPOCHS)
loss_groups = [
    [('Gen_G_Loss', Gen_G_Loss), ('Gen_F_Loss', Gen_F_Loss)],
    [('Total_Cycle_Loss', Total_Cycle_Loss),
     ('Total_Gen_G_Loss', Total_Gen_G_Loss),
     ('Total_Gen_F_Loss', Total_Gen_F_Loss)],
    [('Disc_X_Loss', Disc_X_Loss), ('Disc_Y_Loss', Disc_Y_Loss)],
]
for group in loss_groups:
  plt.figure()
  for name, values in group:
    plt.plot(epochs_axis, values, label=name)
  plt.legend()


## Generate sample outputs (cells below use the training file lists)

In [0]:
import cv2

In [0]:
# Change the working directory so the cv2.imwrite calls below write into the results folder.
# NOTE(review): any relative path used after this cell now resolves against this directory.
cd '/content/drive/My Drive/Course/CS230/CycleGAN/Ching-Ting/results'

In [0]:
def normalization(img):
  """Converts an RGB image to single-channel grayscale and multiplies it by 255.

  This re-defines the helper of the same name earlier in the notebook; it must
  stay identical, because this later definition shadows the earlier one.
  The inputs are TensorFlow images/predictions, which are RGB-ordered, so
  COLOR_RGB2GRAY is used; COLOR_BGR2GRAY would swap the red and blue weights.

  Args:
    img: (H, W, 3) array in RGB channel order (float32 at the call sites below).

  Returns:
    (H, W) grayscale array multiplied by 255.
  """
  gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
  return gray * 255

In [0]:
# Export one input image and generator G's translation of it as grayscale PNGs.
sample_image = next(iter(image_train_ds))
img = normalization(np.float32(sample_image[0]))
convert = normalization(np.float32(generator_g(sample_image)[0]))
for out_name, out_arr in (("img8.png", img), ("lbl8.png", convert)):
  cv2.imwrite(out_name, out_arr)

In [0]:
# Export one label and generator F's translation of it as grayscale PNGs.
sample_label = next(iter(label_train_ds))
img = normalization(np.float32(sample_label[0]))
convert = normalization(np.float32(generator_f(sample_label)[0]))
for out_name, out_arr in (("label6.png", img), ("SynImg6.png", convert)):
  cv2.imwrite(out_name, out_arr)

In [0]:
# Run the trained model on the test dataset
for inp in image_train_ds.take(10):
  generate_images(generator_g, inp)

In [0]:
# Run the trained model on the test dataset
for inp in label_train_ds.take(5):
  generate_images(generator_f, inp)

## Next steps

This tutorial has shown how to implement CycleGAN starting from the generator and discriminator implemented in the [Pix2Pix](https://www.tensorflow.org/tutorials/generative/pix2pix) tutorial. As a next step, you could try using a different dataset from [TensorFlow Datasets](https://www.tensorflow.org/datasets/datasets#cycle_gan). 

You could also train for a larger number of epochs to improve the results, or you could implement the modified ResNet generator used in the [paper](https://arxiv.org/abs/1703.10593) instead of the U-Net generator used here.

In [0]:
# List the files the CheckpointManager has written so far.
fnames = [name for name in os.listdir(checkpoint_path)]
print(fnames)

In [0]:
# Restore into `ckpt`, which tracks the generators, discriminators and their
# optimizers. An empty tf.train.Checkpoint() tracks no objects, so a restore
# into it loads nothing into the models.
ckpt_load = ckpt
status = ckpt_load.restore(os.path.join(checkpoint_path, "ckpt-52"))

In [0]:
# Raise if any value stored in the checkpoint was not matched to an object
# tracked by the Checkpoint used for the restore.
# NOTE(review): optimizer slot variables are created lazily, so this can fail
# until the optimizers have applied gradients in this session.
status.assert_consumed()