# CycleGAN


This notebook assumes you are familiar with Pix2Pix, which you can learn about in the [Pix2Pix tutorial](https://www.tensorflow.org/tutorials/generative/pix2pix). The code for CycleGAN is similar; the main differences are an additional loss function and the use of unpaired training data.

CycleGAN uses a cycle consistency loss to enable training without the need for paired data. In other words, it can translate from one domain to another without a one-to-one mapping between the source and target domain. 

This makes tasks such as photo enhancement, image colorization, and style transfer possible. All you need are a source dataset and a target dataset, each of which is simply a directory of images.

![Output Image 1](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/horse2zebra_1.png?raw=1)
![Output Image 2](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/horse2zebra_2.png?raw=1)

## Set up the input pipeline

Install the [tensorflow_examples](https://github.com/tensorflow/examples) package, which provides the Pix2Pix models, and the `keras-rectified-adam` package, which `model.py` imports later in this notebook.

In [0]:
!pip install git+https://github.com/tensorflow/examples.git
!pip install keras-rectified-adam

In [0]:
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass
import tensorflow as tf

In [0]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

tfds.disable_progress_bar()
AUTOTUNE = tf.data.experimental.AUTOTUNE

## Input Pipeline

This tutorial trains a model to translate images of horses into images of zebras. You can find this dataset and similar ones [here](https://www.tensorflow.org/datasets/datasets#cycle_gan).

As mentioned in the [paper](https://arxiv.org/abs/1703.10593), apply random jittering and mirroring to the training dataset. These image augmentation techniques help avoid overfitting.

This is similar to what was done in [pix2pix](https://www.tensorflow.org/tutorials/generative/pix2pix#load_the_dataset).

* In random jittering, the image is resized to `286 x 286` and then randomly cropped to `256 x 256`.
* In random mirroring, the image is randomly flipped horizontally, i.e., left to right.
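
To get a feel for how much variety these two augmentations add, here is a small sketch in plain Python. It is not the `tf.image` pipeline used below: the helper names `count_views` and `toy_jitter` are hypothetical, and the image is a tiny list of rows standing in for a real tensor. It assumes crop offsets are chosen uniformly among all positions where the crop fits.

```python
import random

RESIZED, CROP = 286, 256  # sizes quoted in the bullets above


def count_views(resized=RESIZED, crop=CROP):
    """Number of distinct (crop offset, flip) combinations for one image."""
    offsets_per_axis = resized - crop + 1   # valid top-left positions along one axis
    return offsets_per_axis ** 2 * 2        # every crop can also be mirrored


def toy_jitter(image, crop, rng):
    """Randomly crop a list-of-rows image, then mirror it half of the time."""
    top = rng.randint(0, len(image) - crop)       # randint is inclusive on both ends
    left = rng.randint(0, len(image[0]) - crop)
    patch = [row[left:left + crop] for row in image[top:top + crop]]
    if rng.random() < 0.5:
        patch = [row[::-1] for row in patch]      # horizontal (left-right) flip
    return patch


toy_image = [[r * 4 + c for c in range(4)] for r in range(4)]  # 4 x 4 stand-in
patch = toy_jitter(toy_image, crop=3, rng=random.Random(0))

print(count_views())  # 31 * 31 offsets, times 2 for mirroring
```

Under these assumptions a single training image can yield 1922 different views, which is why the augmentation needs to be re-sampled on every epoch to be useful.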

In [0]:
dataset, metadata = tfds.load('cycle_gan/horse2zebra',
                              with_info=True, as_supervised=True)

train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']

In [0]:
BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256

In [0]:
def random_crop(image):
  cropped_image = tf.image.random_crop(
      image, size=[IMG_HEIGHT, IMG_WIDTH, 3])

  return cropped_image

In [0]:
# normalizing the images to [-1, 1]
def normalize(image):
  image = tf.cast(image, tf.float32)
  image = (image / 127.5) - 1
  return image
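
The scaling above, and the `* 0.5 + 0.5` that appears in the plotting cells, are inverse steps of the same idea. The following is a minimal sketch on plain Python floats; the names `to_model_range` and `to_plot_range` are hypothetical and exist only for this illustration.

```python
def to_model_range(pixel):
    """Map an 8-bit pixel value in [0, 255] to [-1, 1], as normalize() does."""
    return pixel / 127.5 - 1.0


def to_plot_range(value):
    """Map a model-range value in [-1, 1] to [0, 1] for plt.imshow."""
    return value * 0.5 + 0.5


for pixel in (0, 127.5, 255):
    scaled = to_model_range(pixel)
    print(pixel, scaled, to_plot_range(scaled))
```

Black (0) maps to -1 and white (255) maps to 1, which matches the range of the `tanh` activation at the end of the generator.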

In [0]:
def random_jitter(image):
  # resizing to 286 x 286 x 3
  image = tf.image.resize(image, [286, 286],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  # randomly cropping to 256 x 256 x 3
  image = random_crop(image)

  # random mirroring
  image = tf.image.random_flip_left_right(image)

  return image

In [0]:
def preprocess_image_train(image, label):
  image = random_jitter(image)
  image = normalize(image)
  return image

In [0]:
def preprocess_image_test(image, label):
  image = normalize(image)
  return image

In [0]:
# Cache the decoded images before the random augmentation so that every
# epoch sees freshly jittered and mirrored samples.
train_horses = train_horses.cache().map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

train_zebras = train_zebras.cache().map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

# Test preprocessing is deterministic, so caching after the map is safe.
test_horses = test_horses.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

test_zebras = test_zebras.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

In [0]:
sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))

In [0]:
plt.subplot(121)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Horse with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)

In [0]:
plt.subplot(121)
plt.title('Zebra')
plt.imshow(sample_zebra[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Zebra with random jitter')
plt.imshow(random_jitter(sample_zebra[0]) * 0.5 + 0.5)

## Import and reuse the Pix2Pix models

Import the discriminator used in [Pix2Pix](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py) via the installed [tensorflow_examples](https://github.com/tensorflow/examples) package. The generators in this notebook are not the Pix2Pix generators: they are U-Nets built from octave convolutions, defined in the `keras_octave_conv.py` and `model.py` cells below.

The overall setup is very similar to what was used in [pix2pix](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py). One difference is:

* CycleGAN uses [instance normalization](https://arxiv.org/abs/1607.08022) instead of [batch normalization](https://arxiv.org/abs/1502.03167).
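
To make that difference concrete, here is a minimal sketch in plain Python rather than the Keras layer defined later in this notebook. It assumes a single-channel feature map flattened into a list and leaves out the learned scale and offset; the names `toy_instance_norm` and `toy_batch_norm` are hypothetical. Instance normalization computes statistics per sample, so two images that differ only in brightness normalize to the same values, while pooled batch statistics keep them apart.

```python
import math
from statistics import fmean, pvariance


def toy_instance_norm(sample, eps=1e-5):
    """Normalize one feature map with its own mean and variance."""
    mu, var = fmean(sample), pvariance(sample)
    return [(v - mu) / math.sqrt(var + eps) for v in sample]


def toy_batch_norm(batch, eps=1e-5):
    """Normalize every sample with statistics pooled over the whole batch."""
    pooled = [v for sample in batch for v in sample]
    mu, var = fmean(pooled), pvariance(pooled)
    return [[(v - mu) / math.sqrt(var + eps) for v in sample] for sample in batch]


dark = [0.0, 1.0, 2.0, 3.0]
bright = [10.0, 11.0, 12.0, 13.0]  # same pattern, shifted by +10

inst_dark, inst_bright = toy_instance_norm(dark), toy_instance_norm(bright)
bn_dark, bn_bright = toy_batch_norm([dark, bright])

print(inst_dark, inst_bright)  # per-sample statistics remove the brightness shift
print(bn_dark, bn_bright)      # pooled statistics preserve it
```

With a batch size of 1, as used in this notebook, per-sample statistics are also the only statistics available at training time.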

There are 2 generators (`G` and `F`) and 2 discriminators (`D_X` and `D_Y`) being trained here.

* Generator `G` learns to transform image `X` to image `Y`. $(G: X \rightarrow Y)$
* Generator `F` learns to transform image `Y` to image `X`. $(F: Y \rightarrow X)$
* Discriminator `D_X` learns to differentiate between image `X` and generated image `X` (`F(Y)`).
* Discriminator `D_Y` learns to differentiate between image `Y` and generated image `Y` (`G(X)`).

![Cyclegan model](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/cyclegan_model.png?raw=1)

In [0]:
%%writefile keras_octave_conv.py
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import *
import tensorflow as tf


# Only list names that are actually defined in this module.
__all__ = ['OctaveConv2D', 'instance_norm']


class OctaveConv2D(Layer):
    """Octave convolutions.
    # Arguments
        octave: The division of the spatial dimensions by a power of 2.
        ratio_out: The ratio of filters for lower spatial resolution.
    # References
        - [Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution]
          (https://arxiv.org/pdf/1904.05049.pdf)
    """

    def __init__(self,
                 filters,
                 kernel_size=(3,3),
                 octave=2,
                 ratio_out=0.125,
                 strides=(1, 1),
                 data_format=None,
                 dilation_rate=(1, 1),
                 activation=None,
                 use_bias=False,
                 use_transpose=False,
                 kernel_initializer='he_normal',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        super(OctaveConv2D, self).__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.octave = octave
        self.ratio_out = ratio_out
        self.strides = strides
        self.data_format = data_format
        self.dilation_rate = dilation_rate
        self.use_bias = use_bias
        self.use_transpose = use_transpose
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer
        self.kernel_regularizer = kernel_regularizer
        self.bias_regularizer = bias_regularizer
        self.activity_regularizer = activity_regularizer
        self.kernel_constraint = kernel_constraint
        self.bias_constraint = bias_constraint

        self.filters_low = int(filters * self.ratio_out)
        self.filters_high = filters - self.filters_low

        self.conv_high_to_high, self.conv_low_to_high = None, None
        if self.use_transpose:
          if self.filters_high > 0:
              self.conv_high_to_high = self._init_transconv(self.filters_high, name='{}-Trans-Conv2D-HH'.format(self.name))
              self.conv_low_to_high = self._init_transconv(self.filters_high, name='{}-Conv2D-LH'.format(self.name))
          self.conv_low_to_low, self.conv_high_to_low = None, None
          if self.filters_low > 0:
              self.conv_low_to_low = self._init_transconv(self.filters_low, name='{}-Trans-Conv2D-HL'.format(self.name))
              self.conv_high_to_low = self._init_transconv(self.filters_low, name='{}-Trans-Conv2D-LL'.format(self.name))
          self.pooling = AveragePooling2D(
              pool_size=self.octave,
              padding='valid',
              data_format=data_format,
              name='{}-AveragePooling2D'.format(self.name),
          )
          self.up_sampling = UpSampling2D(
              size=self.octave,
              data_format=data_format,
              name='{}-UpSampling2D'.format(self.name)
          )
        else:
          if self.filters_high > 0:
              self.conv_high_to_high = self._init_conv(self.filters_high, name='{}-Conv2D-HH'.format(self.name))
              self.conv_low_to_high = self._init_conv(self.filters_high, name='{}-Conv2D-LH'.format(self.name))
          self.conv_low_to_low, self.conv_high_to_low = None, None
          if self.filters_low > 0:
              self.conv_low_to_low = self._init_conv(self.filters_low, name='{}-Conv2D-HL'.format(self.name))
              self.conv_high_to_low = self._init_conv(self.filters_low, name='{}-Conv2D-LL'.format(self.name))
          self.pooling = AveragePooling2D(
              pool_size=self.octave,
              padding='valid',
              data_format=data_format,
              name='{}-AveragePooling2D'.format(self.name),
          )
          self.up_sampling = UpSampling2D(
              size=self.octave,
              data_format=data_format,
              name='{}-UpSampling2D'.format(self.name)
          )
    def _init_transconv(self, filters, name):
        return Conv2DTranspose(
            filters=filters,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding='same',
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
            use_bias=self.use_bias,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            kernel_regularizer=self.kernel_regularizer,
            bias_regularizer=self.bias_regularizer,
            activity_regularizer=self.activity_regularizer,
            kernel_constraint=self.kernel_constraint,
            bias_constraint=self.bias_constraint,
            name=name,
        )

    def _init_conv(self, filters, name):
        return Conv2D(
            filters=filters,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding='same',
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
            use_bias=self.use_bias,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            kernel_regularizer=self.kernel_regularizer,
            bias_regularizer=self.bias_regularizer,
            activity_regularizer=self.activity_regularizer,
            kernel_constraint=self.kernel_constraint,
            bias_constraint=self.bias_constraint,
            name=name,
        )

    def build(self, input_shape):
        if isinstance(input_shape, list):
            input_shape_high, input_shape_low = input_shape
        else:
            input_shape_high, input_shape_low = input_shape, None
        if self.data_format == 'channels_first':
            channel_axis, rows_axis, cols_axis = 1, 2, 3
        else:
            rows_axis, cols_axis, channel_axis = 1, 2, 3
        if input_shape_high[channel_axis] is None:
            raise ValueError('The channel dimension of the higher spatial inputs '
                             'should be defined. Found `None`.')
        if input_shape_low is not None and input_shape_low[channel_axis] is None:
            raise ValueError('The channel dimension of the lower spatial inputs '
                             'should be defined. Found `None`.')
        if input_shape_high[rows_axis] is not None and input_shape_high[rows_axis] % self.octave != 0 or \
           input_shape_high[cols_axis] is not None and input_shape_high[cols_axis] % self.octave != 0:
            raise ValueError('The rows and columns of the higher spatial inputs should be divisible by the octave. '
                             'Found {} and {}.'.format(input_shape_high, self.octave))
        if input_shape_low is None:
            self.conv_low_to_high, self.conv_low_to_low = None, None

        if self.conv_high_to_high is not None:
            with K.name_scope(self.conv_high_to_high.name):
                self.conv_high_to_high.build(input_shape_high)
        if self.conv_low_to_high is not None:
            with K.name_scope(self.conv_low_to_high.name):
                self.conv_low_to_high.build(input_shape_low)
        if self.conv_high_to_low is not None:
            with K.name_scope(self.conv_high_to_low.name):
                self.conv_high_to_low.build(input_shape_high)
        if self.conv_low_to_low is not None:
            with K.name_scope(self.conv_low_to_low.name):
                self.conv_low_to_low.build(input_shape_low)
        super(OctaveConv2D, self).build(input_shape)

    @property
    def trainable_weights(self):
        weights = []
        if self.conv_high_to_high is not None:
            weights += self.conv_high_to_high.trainable_weights
        if self.conv_low_to_high is not None:
            weights += self.conv_low_to_high.trainable_weights
        if self.conv_high_to_low is not None:
            weights += self.conv_high_to_low.trainable_weights
        if self.conv_low_to_low is not None:
            weights += self.conv_low_to_low.trainable_weights
        return weights

    @property
    def non_trainable_weights(self):
        weights = []
        if self.conv_high_to_high is not None:
            weights += self.conv_high_to_high.non_trainable_weights
        if self.conv_low_to_high is not None:
            weights += self.conv_low_to_high.non_trainable_weights
        if self.conv_high_to_low is not None:
            weights += self.conv_high_to_low.non_trainable_weights
        if self.conv_low_to_low is not None:
            weights += self.conv_low_to_low.non_trainable_weights
        return weights

    def compute_output_shape(self, input_shape):
        if isinstance(input_shape, list):
            input_shape_high, input_shape_low = input_shape
        else:
            input_shape_high, input_shape_low = input_shape, None

        output_shape_high = None
        if self.filters_high > 0:
            output_shape_high = self.conv_high_to_high.compute_output_shape(input_shape_high)
        output_shape_low = None
        if self.filters_low > 0:
            output_shape_low = self.conv_high_to_low.compute_output_shape(
                self.pooling.compute_output_shape(input_shape_high),
            )

        if self.filters_low == 0:
            return output_shape_high
        if self.filters_high == 0:
            return output_shape_low
        return [output_shape_high, output_shape_low]

    def call(self, inputs, **kwargs):
        if isinstance(inputs, list):
            inputs_high, inputs_low = inputs
        else:
            inputs_high, inputs_low = inputs, None

        outputs_high_to_high, outputs_low_to_high = 0.0, 0.0
        if self.use_transpose:
          if self.conv_high_to_high is not None:
              outputs_high_to_high = self.conv_high_to_high(inputs_high)
          if self.conv_low_to_high is not None:
              outputs_low_to_high = self.up_sampling(self.conv_low_to_high(inputs_low))
          outputs_high = outputs_high_to_high + outputs_low_to_high

          outputs_low_to_low, outputs_high_to_low = 0.0, 0.0
          if self.conv_low_to_low is not None:
              outputs_low_to_low = self.conv_low_to_low(inputs_low)
          if self.conv_high_to_low is not None:
              outputs_high_to_low = self.pooling(self.conv_high_to_low(inputs_high))
          outputs_low = outputs_low_to_low + outputs_high_to_low

          if self.filters_low == 0:
              return outputs_high
          if self.filters_high == 0:
              return outputs_low
        else:
          if self.conv_high_to_high is not None:
              outputs_high_to_high = self.conv_high_to_high(inputs_high)
          if self.conv_low_to_high is not None:
              outputs_low_to_high = self.up_sampling(self.conv_low_to_high(inputs_low))
          outputs_high = outputs_high_to_high + outputs_low_to_high

          outputs_low_to_low, outputs_high_to_low = 0.0, 0.0
          if self.conv_low_to_low is not None:
              outputs_low_to_low = self.conv_low_to_low(inputs_low)
          if self.conv_high_to_low is not None:
              outputs_high_to_low = self.conv_high_to_low(self.pooling(inputs_high))
          outputs_low = outputs_low_to_low + outputs_high_to_low

          if self.filters_low == 0:
              return outputs_high
          if self.filters_high == 0:
              return outputs_low
        return [outputs_high, outputs_low]

    def get_config(self):
        config = {
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'octave': self.octave,
            'ratio_out': self.ratio_out,
            'strides': self.strides,
            'data_format': self.data_format,
            'dilation_rate': self.dilation_rate,
            'use_bias': self.use_bias,
            'use_transpose': self.use_transpose,
            'kernel_initializer': self.kernel_initializer,
            'bias_initializer': self.bias_initializer,
            'kernel_regularizer': self.kernel_regularizer,
            'bias_regularizer': self.bias_regularizer,
            'activity_regularizer': self.activity_regularizer,
            'kernel_constraint': self.kernel_constraint,
            'bias_constraint': self.bias_constraint
        }
        base_config = super(OctaveConv2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

class instance_norm(tf.keras.layers.Layer):
  """Instance Normalization Layer (https://arxiv.org/abs/1607.08022)."""

  def __init__(self, epsilon=1e-5):
    super(instance_norm, self).__init__()
    self.epsilon = epsilon

  def build(self, input_shape):
    self.scale = self.add_weight(
        name='scale',
        shape=input_shape[-1:],
        initializer=tf.random_normal_initializer(1., 0.02),
        trainable=True)

    self.offset = self.add_weight(
        name='offset',
        shape=input_shape[-1:],
        initializer='zeros',
        trainable=True)

  def call(self, x):
    mean, variance = tf.nn.moments(x, axes=[1, 2], keepdims=True)
    inv = tf.math.rsqrt(variance + self.epsilon)
    normalized = (x - mean) * inv
    return self.scale * normalized + self.offset

In [0]:
%%writefile model.py
# model code all in this cell

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import math
import numpy as np
import tensorflow as tf
from sklearn.utils import class_weight

from keras_radam.training import RAdamOptimizer
from tensorflow.keras import layers
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import *
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from keras_octave_conv import OctaveConv2D
from keras_octave_conv import instance_norm

      
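def unet(pretrained_weights=None, input_size=(256, 256, 3)):
    """Build the octave-convolution U-Net used here as a CycleGAN generator.

    Optionally loads weights from `pretrained_weights`. The final layer is a
    1x1 convolution with 3 output channels and a tanh activation.
    """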
    inputs = Input(input_size)
    # downsampling for lower
    low = layers.AveragePooling2D(2)(inputs)
    high1, low1 = OctaveConv2D(64)([inputs,low])
    high1 = instance_norm()(high1)
    high1 = layers.Activation("relu")(high1)
    low1 = instance_norm()(low1)
    low1 = layers.Activation("relu")(low1)
    high1, low1 = OctaveConv2D(64)([high1, low1])
    high1 = instance_norm()(high1)
    high1 = layers.Activation("relu")(high1)
    low1 = instance_norm()(low1)
    low1 = layers.Activation("relu")(low1)
    pool1high = layers.MaxPooling2D(2)(high1)
    pool1low = layers.MaxPooling2D(2)(low1)
    
    high2, low2 = OctaveConv2D(128)([pool1high,pool1low])
    high2 = instance_norm()(high2)
    high2 = layers.Activation("relu")(high2)
    low2 = instance_norm()(low2)
    low2 = layers.Activation("relu")(low2)
    high2, low2 = OctaveConv2D(128)([high2, low2])
    high2 = instance_norm()(high2)
    high2 = layers.Activation("relu")(high2)
    low2 = instance_norm()(low2)
    low2 = layers.Activation("relu")(low2)
    pool2high = layers.MaxPooling2D(2)(high2)
    pool2low = layers.MaxPooling2D(2)(low2)
    
    high3, low3 = OctaveConv2D(256)([pool2high,pool2low])
    high3 = instance_norm()(high3)
    high3 = layers.Activation("relu")(high3)
    low3 = instance_norm()(low3)
    low3 = layers.Activation("relu")(low3)
    high3, low3 = OctaveConv2D(256)([high3, low3])
    high3 = instance_norm()(high3)
    high3 = layers.Activation("relu")(high3)
    low3 = instance_norm()(low3)
    low3 = layers.Activation("relu")(low3)
    pool3high = layers.MaxPooling2D(2)(high3)
    pool3low = layers.MaxPooling2D(2)(low3)
    
    high4, low4 = OctaveConv2D(512)([pool3high,pool3low])
    high4 = instance_norm()(high4)
    high4 = layers.Activation("relu")(high4)
    low4 = instance_norm()(low4)
    low4 = layers.Activation("relu")(low4)
    high4, low4 = OctaveConv2D(512)([high4, low4])
    high4 = instance_norm()(high4)
    high4 = layers.Activation("relu")(high4)
    low4 = instance_norm()(low4)
    low4 = layers.Activation("relu")(low4)
    pool4high = layers.MaxPooling2D(2)(high4)
    pool4low = layers.MaxPooling2D(2)(low4)

    high5, low5 = OctaveConv2D(1024)([pool4high, pool4low])
    high5 = instance_norm()(high5)
    high5 = layers.Activation("relu")(high5)
    low5 = instance_norm()(low5)
    low5 = layers.Activation("relu")(low5)
    high5 = Dropout(0.4)(high5)
    low5 = Dropout(0.4)(low5)
    high5, low5 = OctaveConv2D(1024)([high5, low5])
    high5 = instance_norm()(high5)
    high5 = layers.Activation("relu")(high5)
    low5 = instance_norm()(low5)
    low5 = layers.Activation("relu")(low5)
    high5 = Dropout(0.4)(high5)
    low5 = Dropout(0.4)(low5)
    
    uphigh6, uplow6 = OctaveConv2D(512, use_transpose=True, strides=(2,2))([high5,low5])
    uphigh6 = instance_norm()(uphigh6)
    uphigh6 = layers.Activation("relu")(uphigh6)
    uplow6 = instance_norm()(uplow6)
    uplow6 = layers.Activation("relu")(uplow6)
    merge6high = concatenate([high4,uphigh6], axis = 3)
    merge6low = concatenate([low4,uplow6], axis = 3)
    high6, low6 = OctaveConv2D(512)([merge6high,merge6low])
    high6 = instance_norm()(high6)
    high6 = layers.Activation("relu")(high6)
    low6 = instance_norm()(low6)
    low6 = layers.Activation("relu")(low6)
    high6, low6 = OctaveConv2D(512)([high6, low6])
    high6 = instance_norm()(high6)
    high6 = layers.Activation("relu")(high6)
    low6 = instance_norm()(low6)
    low6 = layers.Activation("relu")(low6)


    uphigh7, uplow7 = OctaveConv2D(256, use_transpose=True, strides=(2,2))([high6, low6])
    uphigh7 = instance_norm()(uphigh7)
    uphigh7 = layers.Activation("relu")(uphigh7)
    uplow7 = instance_norm()(uplow7)
    uplow7 = layers.Activation("relu")(uplow7)
    merge7high = concatenate([high3,uphigh7], axis = 3)
    merge7low = concatenate([low3,uplow7], axis = 3)
    high7, low7 = OctaveConv2D(256)([merge7high, merge7low])
    high7 = instance_norm()(high7)
    high7 = layers.Activation("relu")(high7)
    low7 = instance_norm()(low7)
    low7 = layers.Activation("relu")(low7)
    high7, low7 = OctaveConv2D(256)([high7, low7])
    high7 = instance_norm()(high7)
    high7 = layers.Activation("relu")(high7)
    low7 = instance_norm()(low7)
    low7 = layers.Activation("relu")(low7)

    uphigh8, uplow8 = OctaveConv2D(128, use_transpose=True, strides=(2,2))([high7, low7])
    uphigh8 = instance_norm()(uphigh8)
    uphigh8 = layers.Activation("relu")(uphigh8)
    uplow8 = instance_norm()(uplow8)
    uplow8 = layers.Activation("relu")(uplow8)
    merge8high = concatenate([high2,uphigh8], axis = 3)
    merge8low = concatenate([low2,uplow8], axis = 3)
    high8, low8 = OctaveConv2D(128)([merge8high, merge8low])
    high8 = instance_norm()(high8)
    high8 = layers.Activation("relu")(high8)
    low8 = instance_norm()(low8)
    low8 = layers.Activation("relu")(low8)
    high8, low8 = OctaveConv2D(128)([high8, low8])
    high8 = instance_norm()(high8)
    high8 = layers.Activation("relu")(high8)
    low8 = instance_norm()(low8)
    low8 = layers.Activation("relu")(low8)

    uphigh9, uplow9 = OctaveConv2D(64, use_transpose=True, strides=(2,2))([high8, low8])
    uphigh9 = instance_norm()(uphigh9)
    uphigh9 = layers.Activation("relu")(uphigh9)
    uplow9 = instance_norm()(uplow9)
    uplow9 = layers.Activation("relu")(uplow9)
    merge9high = concatenate([high1,uphigh9], axis = 3)
    merge9low = concatenate([low1,uplow9], axis = 3)
    high9, low9 = OctaveConv2D(64)([merge9high, merge9low])
    high9 = instance_norm()(high9)
    high9 = layers.Activation("relu")(high9)
    low9 = instance_norm()(low9)
    low9 = layers.Activation("relu")(low9)
    high9, low9 = OctaveConv2D(64)([high9, low9])
    high9 = instance_norm()(high9)
    high9 = layers.Activation("relu")(high9)
    low9 = instance_norm()(low9)
    low9 = layers.Activation("relu")(low9)
    conv9 = OctaveConv2D(32, ratio_out=0.0)([high9, low9])
    conv9 = layers.Activation("sigmoid")(conv9)
    conv10 = layers.Conv2D(3, 1, activation = 'tanh')(conv9)

    model = Model(inputs=inputs, outputs=conv10)
    
    # model.summary()
    
    # model.compile(optimizer = RAdamOptimizer(learning_rate=1e-4), loss = 'binary_crossentropy', metrics = ['accuracy'])

    if pretrained_weights:
        model.load_weights(pretrained_weights)

    return model

In [0]:
from model import *
OUTPUT_CHANNELS = 3

generator_g = unet()
generator_f = unet()

discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)

In [0]:
to_zebra = generator_g(sample_horse)
to_horse = generator_f(sample_zebra)
plt.figure(figsize=(8, 8))
contrast = 8

imgs = [sample_horse, to_zebra, sample_zebra, to_horse]
title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']

for i in range(len(imgs)):
  plt.subplot(2, 2, i+1)
  plt.title(title[i])
  if i % 2 == 0:
    plt.imshow(imgs[i][0] * 0.5 + 0.5)
  else:
    plt.imshow(imgs[i][0] * 0.5 * contrast + 0.5)
plt.show()

In [0]:
plt.figure(figsize=(8, 8))

plt.subplot(121)
plt.title('Is a real zebra?')
plt.imshow(discriminator_y(sample_zebra)[0, ..., -1], cmap='RdBu_r')

plt.subplot(122)
plt.title('Is a real horse?')
plt.imshow(discriminator_x(sample_horse)[0, ..., -1], cmap='RdBu_r')

plt.show()

## Loss functions

In CycleGAN, there is no paired data to train on, so there is no guarantee that an input `x` and a target `y` form a meaningful pair during training. To encourage the network to learn a meaningful mapping anyway, the authors propose the cycle consistency loss.

The discriminator loss and the generator loss are similar to the ones used in [pix2pix](https://www.tensorflow.org/tutorials/generative/pix2pix#define_the_loss_functions_and_the_optimizer).

In [0]:
LAMBDA = 10

In [0]:
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [0]:
def discriminator_loss(real, generated):
  real_loss = loss_obj(tf.ones_like(real), real)

  generated_loss = loss_obj(tf.zeros_like(generated), generated)

  total_disc_loss = real_loss + generated_loss

  return total_disc_loss * 0.5

In [0]:
def generator_loss(generated):
  return loss_obj(tf.ones_like(generated), generated)
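
A worked example may help show what these two losses reward. The sketch below reimplements sigmoid cross-entropy on logits in plain Python; this is an assumption made for illustration, since the notebook itself uses `tf.keras.losses.BinaryCrossentropy(from_logits=True)`. The `toy_` names are hypothetical, and the logits are made-up numbers rather than real discriminator outputs.

```python
import math


def toy_bce_from_logits(label, logit):
    """Numerically stable sigmoid cross-entropy for one logit."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def toy_mean(values):
    values = list(values)
    return sum(values) / len(values)


def toy_discriminator_loss(real_logits, fake_logits):
    """Real patches should score as 1, generated patches as 0."""
    real_loss = toy_mean(toy_bce_from_logits(1.0, z) for z in real_logits)
    fake_loss = toy_mean(toy_bce_from_logits(0.0, z) for z in fake_logits)
    return 0.5 * (real_loss + fake_loss)


def toy_generator_loss(fake_logits):
    """The generator wants its outputs to be scored as real (label 1)."""
    return toy_mean(toy_bce_from_logits(1.0, z) for z in fake_logits)


# An undecided discriminator outputs logit 0 everywhere (probability 0.5).
undecided = toy_discriminator_loss([0.0, 0.0], [0.0, 0.0])

# A confident, correct discriminator: real patches at +5, fakes at -5.
confident_disc = toy_discriminator_loss([5.0, 5.0], [-5.0, -5.0])
confident_gen = toy_generator_loss([-5.0, -5.0])

print(undecided, confident_disc, confident_gen)
```

The undecided case gives a loss of ln 2 for both players. When the discriminator becomes confident and correct, its own loss shrinks toward zero while the generator loss grows, which is the adversarial pressure that pushes the generator to improve.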

Cycle consistency means the result should be close to the original input. For example, if one translates a sentence from English to French, and then translates it back from French to English, the resulting sentence should be the same as the original sentence.

In cycle consistency loss, 

* Image $X$ is passed via generator $G$ that yields generated image $\hat{Y}$.
* Generated image $\hat{Y}$ is passed via generator $F$ that yields cycled image $\hat{X}$.
* Mean absolute error is calculated between $X$ and $\hat{X}$.

$$\text{forward cycle consistency loss}: X \rightarrow G(X) \rightarrow F(G(X)) \sim \hat{X}$$

$$\text{backward cycle consistency loss}: Y \rightarrow F(Y) \rightarrow G(F(Y)) \sim \hat{Y}$$


![Cycle loss](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/cycle_loss.png?raw=1)

In [0]:
def calc_cycle_loss(real_image, cycled_image):
  loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
  
  return LAMBDA * loss1
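
The same computation can be followed by hand on scalars. In this sketch the "generators" are simple arithmetic functions chosen for illustration (hypothetical stand-ins, not trained networks), and `TOY_LAMBDA` mirrors the `LAMBDA = 10` weight set earlier.

```python
TOY_LAMBDA = 10


def toy_cycle_loss(real, cycled):
    """LAMBDA times the mean absolute error between inputs and their round trips."""
    errors = [abs(r - c) for r, c in zip(real, cycled)]
    return TOY_LAMBDA * sum(errors) / len(errors)


def toy_g(x):
    """Stand-in for G: X -> Y."""
    return 2.0 * x + 1.0


def toy_f_good(y):
    """Stand-in for F that exactly inverts toy_g."""
    return (y - 1.0) / 2.0


def toy_f_bad(y):
    """Stand-in for F that does not invert toy_g."""
    return y / 2.0


xs = [0.0, 1.0, 2.0]
good = toy_cycle_loss(xs, [toy_f_good(toy_g(x)) for x in xs])
bad = toy_cycle_loss(xs, [toy_f_bad(toy_g(x)) for x in xs])

print(good, bad)
```

When `F` undoes `G` exactly, the round trip returns the input and the loss is zero. The mismatched pair is off by 0.5 on every sample, so the loss is 10 * 0.5 = 5.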

As shown above, generator $G$ is responsible for translating image $X$ to image $Y$. Identity loss says that if you feed image $Y$ to generator $G$, it should yield the real image $Y$ or something close to image $Y$.

$$\text{Identity loss} = |G(Y) - Y| + |F(X) - X|$$

In [0]:
def identity_loss(real_image, same_image):
  loss = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * loss

Initialize the optimizers for all the generators and the discriminators.

In [0]:
generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

## Checkpoints

In [0]:
checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')

## Training

Note: This example model is trained for fewer epochs (20) than the paper (200) to keep training time reasonable for this tutorial. Predictions may be less accurate.

In [0]:
EPOCHS = 20

In [0]:
def generate_images(model, test_input):
  prediction = model(test_input)
    
  plt.figure(figsize=(12, 12))

  display_list = [test_input[0], prediction[0]]
  title = ['Input Image', 'Predicted Image']

  for i in range(2):
    plt.subplot(1, 2, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

Even though the training loop looks complicated, it consists of four basic steps:

* Get the predictions.
* Calculate the loss.
* Calculate the gradients using backpropagation.
* Apply the gradients to the optimizer.
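
The four steps can be sketched on a one-parameter model, without TensorFlow. This is only an illustration under simplifying assumptions: the gradient is written out by hand instead of being recorded by `tf.GradientTape`, plain gradient descent stands in for Adam, and `toy_train_step` is a hypothetical name.

```python
def toy_train_step(w, xs, ys, learning_rate=0.1):
    """One update of the model y = w * x under mean squared error."""
    # 1. Get the predictions.
    preds = [w * x for x in xs]
    # 2. Calculate the loss.
    loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(xs)
    # 3. Calculate the gradient of the loss with respect to w.
    grad = sum(2.0 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / len(xs)
    # 4. Apply the gradient.
    return w - learning_rate * grad, loss


xs = [1.0, 2.0, 3.0]
ys = [3.0, 6.0, 9.0]  # generated with w = 3

w = 0.0
for _ in range(50):
    w, loss = toy_train_step(w, xs, ys)

print(w, loss)
```

The CycleGAN step below follows the same pattern, repeated for two generators and two discriminators that each have their own loss and optimizer.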

In [0]:
@tf.function
def train_step(real_x, real_y):
  # persistent is set to True because the tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent=True) as tape:
    # Generator G translates X -> Y
    # Generator F translates Y -> X.
    
    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)

    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)

    # same_x and same_y are used for identity loss.
    same_x = generator_f(real_x, training=True)
    same_y = generator_g(real_y, training=True)

    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)

    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # calculate the loss
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)
    
    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)
    
    # Total generator loss = adversarial loss + cycle loss + identity loss
    total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
    total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)
  
  # Calculate the gradients for generator and discriminator
  generator_g_gradients = tape.gradient(total_gen_g_loss, 
                                        generator_g.trainable_variables)
  generator_f_gradients = tape.gradient(total_gen_f_loss, 
                                        generator_f.trainable_variables)
  
  discriminator_x_gradients = tape.gradient(disc_x_loss, 
                                            discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(disc_y_loss, 
                                            discriminator_y.trainable_variables)
  
  # Apply the gradients to the optimizer
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients, 
                                            generator_g.trainable_variables))

  generator_f_optimizer.apply_gradients(zip(generator_f_gradients, 
                                            generator_f.trainable_variables))
  
  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                discriminator_x.trainable_variables))
  
  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                discriminator_y.trainable_variables))

In [0]:
for epoch in range(EPOCHS):
  start = time.time()

  n = 0
  for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)):
    train_step(image_x, image_y)
    if n % 10 == 0:
      print ('.', end='')
    n+=1

  clear_output(wait=True)
  # Using a consistent image (sample_horse) so that the progress of the model
  # is clearly visible.
  generate_images(generator_g, sample_horse)

  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))

  print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                      time.time()-start))

## Generate using test dataset

In [0]:
# Run the trained model on the test dataset
for inp in test_horses.take(5):
  generate_images(generator_g, inp)