# CycleGAN Architecture
The CycleGAN setup builds on the Conditional GAN (CGAN) architecture: it is roughly two CGANs joined together, i.e. two generator-discriminator pairs, each mapping images from one domain to the other. This calls for a dedicated CycleGAN Python class with methods to build the generators and discriminators and to run the training.

Unlike CGANs, CycleGANs condition on a complete image, using an image in the other domain as the label. An image from domain A is fed to discriminator A, which evaluates whether it is real. The image is then fed to the generator, which translates it to domain B, and discriminator B evaluates whether the translation looks real in domain B. Finally, the image is translated back to domain A so that the cycle-consistency loss can be measured.
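
The round trip A -> B -> A described above can be sketched without any neural network. This is only a minimal illustration: `g_ab` and `g_ba` are hypothetical stand-in functions rather than the real generators, and mean absolute error is assumed as the cycle-consistency measure.

```python
def mae(xs, ys):
    """Mean absolute error between two equally sized sequences."""
    return sum(abs(x - y) for x, y in zip(xs, ys)) / len(xs)

# Hypothetical stand-ins for the generators (assumptions, not the real networks)
def g_ab(img):
    return [p + 0.5 for p in img]    # "translate" A -> B

def g_ba(img):
    return [p - 0.4 for p in img]    # "translate" B -> A, deliberately imperfect

img_a = [0.0, 0.2, -0.3]             # a tiny made-up "image" from domain A
fake_b = g_ab(img_a)                 # A -> B
reconstr_a = g_ba(fake_b)            # B -> A
cycle_loss = mae(img_a, reconstr_a)  # how far the round trip drifted from the input
print(round(cycle_loss, 3))
```

If `g_ba` were an exact inverse of `g_ab`, the cycle loss would be zero; the further the reconstruction drifts from the input, the larger the penalty.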


## Import dependencies

In [7]:
from __future__ import print_function, division
import scipy
from keras.datasets import mnist
from keras_contrib.layers import InstanceNormalization
#from keras_contrib.layers.normalization import InstanceNormalization
from keras.layers import (Input, Dense, Reshape, Flatten, Dropout, Concatenate, 
                          BatchNormalization, Activation, ZeroPadding2D, LeakyReLU)
#from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import datetime, sys, os
import matplotlib.pyplot as plt
import numpy as np

## Get helper functions
Install the keras-contrib package directly from GitHub to get InstanceNormalization, then download the dataset.

In [1]:
#@title
!pip install git+https://www.github.com/keras-team/keras-contrib.git

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://www.github.com/keras-team/keras-contrib.git
  Cloning https://www.github.com/keras-team/keras-contrib.git to /tmp/pip-req-build-c0defo75
  Running command git clone --filter=blob:none --quiet https://www.github.com/keras-team/keras-contrib.git /tmp/pip-req-build-c0defo75
  Resolved https://www.github.com/keras-team/keras-contrib.git to commit 3fc5ef709e061416f4bc8a92ca3750c824b5d2b0
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: keras-contrib
  Building wheel for keras-contrib (setup.py) ... done
  Created wheel for keras-contrib: filename=keras_contrib-2.0.8-py3-none-any.whl size=101076 sha256=e1735c904754b3f0059722ed8eb93a6200c9fdc75db9a10d0421c9171c660e7d
  Stored in directory: /tmp/pip-ephem-wheel-cache-yyrldvey/wheels/67/d2/f4/96ae3c3c62d1e05abfc8860ad0c1207794726d44ebbbb547f3
Successfully built keras-co

In [3]:
#@title { form-width: "10%" }
%%bash

FILE=apple2orange

URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
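# -p avoids an error when the directory already exists (e.g. when the cell is re-run)
mkdir -p ./datasets
ZIP_FILE=./datasets/$FILE.zip
TARGET_DIR=./datasets/$FILE/
# -N (timestamping) has no effect in combination with -O, so it is dropped
wget $URL -O $ZIP_FILE
mkdir -p $TARGET_DIR
unzip $ZIP_FILE -d ./datasets/
rm $ZIP_FILE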


In [4]:
#!pip install scikit-image



## The DataLoader
Define a key data-holding object that loads the preprocessed data.

In [6]:
#@title
import scipy
from glob import glob
import numpy as np
import imageio
import skimage
from skimage.transform import resize

class DataLoader():
    def __init__(self, dataset_name, img_res=(128, 128)):
        self.dataset_name = dataset_name
        self.img_res = img_res

    # Load data from disk based on the dataset name defined in the CycleGAN's initializer
    def load_data(self, domain, batch_size=1, is_testing=False):
        data_type = "train%s" % domain if not is_testing else "test%s" % domain
        path = glob('./datasets/%s/%s/*' % (self.dataset_name, data_type))
        batch_images = np.random.choice(path, size=batch_size)

        imgs = []
        for img_path in batch_images:
            img = self.imread(img_path)
            img = resize(img, self.img_res) # scipy.misc.imresize(img, self.img_res) deprecated
            # Random horizontal flip as light augmentation (training only)
            if not is_testing and np.random.random() > 0.5:
                img = np.fliplr(img)
            imgs.append(img)
        imgs = np.array(imgs)/127.5 - 1.

        return imgs

    # Load batch during training
    def load_batch(self, batch_size=1, is_testing=False):
        data_type = "train" if not is_testing else "test"  # the dataset ships train*/test* folders, as used in load_data
        path_A = glob('./datasets/%s/%sA/*' % (self.dataset_name, data_type))
        path_B = glob('./datasets/%s/%sB/*' % (self.dataset_name, data_type))

        # Sample the same number of paths from each domain (limited by the smaller one), without replacement
        self.n_batches = int(min(len(path_A), len(path_B)) / batch_size)
        total_samples = self.n_batches * batch_size
        path_A = np.random.choice(path_A, total_samples, replace=False)
        path_B = np.random.choice(path_B, total_samples, replace=False)

        # Iterate over all n_batches full batches (range(n_batches-1) would skip the last one)
        for i in range(self.n_batches):
            batch_A = path_A[i*batch_size:(i+1)*batch_size]
            batch_B = path_B[i*batch_size:(i+1)*batch_size]
            imgs_A, imgs_B = [], []
            for img_A, img_B in zip(batch_A, batch_B):
                img_A = self.imread(img_A)
                img_B = self.imread(img_B)
                img_A = resize(img_A, self.img_res) # scipy.misc.imresize(img_A, self.img_res) deprecated
                img_B = resize(img_B, self.img_res) # scipy.misc.imresize(img_B, self.img_res) deprecated
                # Random horizontal flip of both images (training only)
                if not is_testing and np.random.random() > 0.5:
                    img_A = np.fliplr(img_A)
                    img_B = np.fliplr(img_B)
                imgs_A.append(img_A)
                imgs_B.append(img_B)
            imgs_A = np.array(imgs_A)/127.5 - 1.
            imgs_B = np.array(imgs_B)/127.5 - 1.

            yield imgs_A, imgs_B

    # Helper function used in load_data/load_batch
    def imread(self, path):
        return imageio.imread(path, pilmode='RGB').astype(np.float64) # np.float alias was removed in NumPy 1.24; replaces deprecated scipy.misc.imread(path, mode='RGB')
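
The loader rescales pixel values with `/127.5 - 1.`, which maps the range [0, 255] to [-1, 1], the output range of a `tanh` activation. A minimal sketch of that mapping and of the inverse `0.5 * x + 0.5` used when plotting; the function names are hypothetical and not part of the DataLoader.

```python
def to_model_range(pixel):
    """Map a pixel value from [0, 255] to [-1, 1], as load_data/load_batch do."""
    return pixel / 127.5 - 1.0

def to_display_range(value):
    """Map a value from [-1, 1] back to [0, 1] for plotting."""
    return 0.5 * value + 0.5

# Darkest, middle and brightest pixel values
for pixel in (0, 127.5, 255):
    scaled = to_model_range(pixel)
    print(pixel, scaled, to_display_range(scaled))
```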

## The CycleGAN class


In [8]:
class CycleGAN():
    def __init__(self):
        # Input shape
        self.img_rows = 128
        self.img_cols = 128
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # Configure data loader
        self.dataset_name = 'apple2orange'
        # Use DataLoader object to import a preprocessed dataset
        self.data_loader = DataLoader(dataset_name=self.dataset_name, img_res=(self.img_rows, self.img_cols))

        # Calculate output shape of D (the two Discriminators use the PatchGAN-based architecture)
        patch = int(self.img_rows / 2**4)
        self.disc_patch = (patch, patch, 1)

        # Number of filters in the first layer of G and D
        self.gf = 32
        self.df = 64

        # Loss weights
        self.lambda_cycle = 10.0                    # Cycle-consistency loss: keeps the reconstructed image close to the original
        self.lambda_id = 0.9 * self.lambda_cycle    # Identity loss: discourages unnecessary changes to images (preserves color space)

        optimizer = Adam(0.0002, 0.5)
        
        # Build and compile the discriminators
        self.d_A = self.build_discriminator()
        self.d_B = self.build_discriminator()
        self.d_A.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
        self.d_B.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])

        # Construct the computational graph of the generators.
        # Overall construction order:
        #   1. Create the two discriminators D_A and D_B and compile them (done above)
        #   2. Create the two generators:
        #      a. Instantiate G_AB and G_BA
        #      b. Create placeholders for the image input for both directions
        #      c. Link them both to an image in the other domain
        #      d. Create placeholders for the reconstructed images back in the original domain
        #      e. Create the identity loss constraint for both directions
        #      f. Freeze the parameters of the discriminators for now
        #      g. Compile the two generators

        # Build the generators
        self.g_AB = self.build_generator()
        self.g_BA = self.build_generator()
        
        # Input images from both domains
        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)
        
        # Translate images to the other domains
        fake_B = self.g_AB(img_A)
        fake_A = self.g_BA(img_B)
        
        # Translate images back to original domain
        reconstr_A = self.g_BA(fake_B)
        reconstr_B = self.g_AB(fake_A)
        
        # Identity mapping of images (identity loss constraint for both directions)
        img_A_id = self.g_BA(img_A)
        img_B_id = self.g_AB(img_B)

        # Only train the generators for the combined model
        self.d_A.trainable = False
        self.d_B.trainable = False

        # Discriminators determine validity of translated images
        valid_A = self.d_A(fake_A)
        valid_B = self.d_B(fake_B)

        # Combined model trains generators to fool discriminators
        self.combined = Model(inputs=[img_A, img_B],
                              #6 outputs: validities, reconstructions, identity losses for each A-B-A and B-A-B cycles
                              outputs=[valid_A, valid_B, reconstr_A, reconstr_B, img_A_id, img_B_id])
        self.combined.compile(loss=['mse', 'mse',
                                    'mae', 'mae',
                                    'mae', 'mae'],
                              loss_weights=[1, 1, self.lambda_cycle, self.lambda_cycle, self.lambda_id, self.lambda_id],
                              optimizer=optimizer)
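
The combined model is trained on a weighted sum of its six output losses. A minimal sketch of that weighting, using the weights from `__init__`; the component loss values and the names in `loss_names` are made up for illustration only.

```python
lambda_cycle = 10.0
lambda_id = 0.9 * lambda_cycle

# Hypothetical labels for the six outputs, in the same order as in the combined model
loss_names = ["valid_A", "valid_B", "reconstr_A", "reconstr_B", "img_A_id", "img_B_id"]
loss_weights = [1, 1, lambda_cycle, lambda_cycle, lambda_id, lambda_id]

# Made-up per-output loss values (assumption, not measured results)
component_losses = [0.5, 0.5, 0.1, 0.1, 0.05, 0.05]

# Total objective = sum of weight * loss over all outputs
total_loss = sum(w * l for w, l in zip(loss_weights, component_losses))

for name, w, l in zip(loss_names, loss_weights, component_losses):
    print(f"{name}: weight={w}, contribution={w * l:.2f}")
print(f"total: {total_loss:.2f}")
```

With these weights, the cycle-consistency and identity terms can dominate the objective even when their raw values are much smaller than the adversarial terms.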

## Static methods
To define the class across multiple cells, each of the following cells declares a CycleGAN class that inherits from the CycleGAN class above (for educational purposes; see the reference below).

In [9]:
class CycleGAN(CycleGAN):  
      # Static method used to shrink the code size
      @staticmethod
      def conv2d(layer_input, filters, f_size=4, normalization=True):
        # Discriminator layer
        d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
        d = LeakyReLU(alpha=0.2)(d)
        if normalization:
            d = InstanceNormalization()(d)
        return d
      
      @staticmethod
      # Transposed convolution function
      def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
            # Layers used during upsampling
            u = UpSampling2D(size=2)(layer_input) # Uses nearest neighbors interpolation (not a learned parameter)
            u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
            if dropout_rate:
                u = Dropout(dropout_rate)(u)
            u = InstanceNormalization()(u)
            # Create skip connection between output layer and corresponding layer from downsampling part
            u = Concatenate()([u, skip_input])
            return u
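
Both helpers rely on instance normalization, which normalizes each channel of each sample separately, and `conv2d` applies a leaky ReLU beforehand. A minimal pure-Python sketch of the two operations on a single flattened channel; it ignores the learned scale and offset parameters, and the `eps` value is an assumption rather than the keras-contrib default.

```python
import math

def leaky_relu(v, alpha=0.2):
    """Pass positive values through, scale negative values by alpha."""
    return v if v > 0 else alpha * v

def instance_norm(channel, eps=1e-3):
    """Normalize one channel of one image to zero mean and (roughly) unit variance."""
    mean = sum(channel) / len(channel)
    var = sum((v - mean) ** 2 for v in channel) / len(channel)
    return [(v - mean) / math.sqrt(var + eps) for v in channel]

activated = [leaky_relu(v) for v in [-1.0, 0.5]]   # negative input is damped, not zeroed
normalized = instance_norm([1.0, 2.0, 3.0, 4.0])   # one channel, flattened
print(activated)
print(normalized)
```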

## The Generator

In [10]:
class CycleGAN(CycleGAN):
    # U-Net Generator
    def build_generator(self):
        # Assign image input to d0 (128×128×3)
        d0 = Input(shape=self.img_shape)

        # Downsampling
        d1 = self.conv2d(d0, self.gf) # Convolve from 128×128×3 layer to 64×64×32 layer
        d2 = self.conv2d(d1, self.gf * 2) # Convolve from 64×64×32 layer to 32×32×64 layer
        d3 = self.conv2d(d2, self.gf * 4) # Convolve from 32×32×64 layer to 16×16×128 layer
        d4 = self.conv2d(d3, self.gf * 8) # Convolve from 16×16×128 layer to 8×8×256 layer

        # Upsampling
        u1 = self.deconv2d(d4, d3, self.gf * 4) # Upsample d4 and create a skip connection between d3 and u1
        u2 = self.deconv2d(u1, d2, self.gf * 2) # Upsample u1 and create a skip connection between d2 and u2
        u3 = self.deconv2d(u2, d1, self.gf) # Upsample u2 and create a skip connection between d1 and u3

        u4 = UpSampling2D(size=2)(u3) # Regular upsampling to arrive at a 128×128×64 image
        output_img = Conv2D(self.channels, kernel_size=4, # Convolve to remove extra feature maps (128×128×3)
                            strides=1, padding='same', activation='tanh')(u4)

        return Model(d0, output_img)
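
The shape comments in the generator can be checked with plain arithmetic. The sketch below traces spatial size and channel count through the U-Net, assuming stride-2 `'same'` convolutions halve the size and each skip connection concatenates along the channel axis; it does not build any Keras layers.

```python
gf = 32      # filters in the first generator layer, as in the class
size = 128   # input height/width

# Downsampling: each conv2d halves the spatial size (stride 2, 'same' padding)
down = []
for filters in (gf, gf * 2, gf * 4, gf * 8):
    size //= 2
    down.append((size, filters))

# Upsampling: each deconv2d doubles the size, convolves to `filters`,
# then concatenates the skip input along the channel axis
up = []
for filters, (skip_size, skip_channels) in zip((gf * 4, gf * 2, gf), reversed(down[:-1])):
    size *= 2
    assert size == skip_size  # a skip connection needs matching spatial size
    up.append((size, filters + skip_channels))

print(down)  # (size, channels) of d1..d4
print(up)    # (size, channels) of u1..u3
```

The last entry of `up` has 64 channels at 64×64, which is why the final `UpSampling2D` yields a 128×128×64 tensor before the output convolution reduces it to 3 channels.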

## The Discriminator

In [11]:
class CycleGAN(CycleGAN):
    def build_discriminator(self):
      img = Input(shape=self.img_shape)

      d1 = self.conv2d(img, self.df, normalization=False) # Input image (128×128×3) to d1 (64×64×64)
      d2 = self.conv2d(d1, self.df * 2) # d1 to d2 (32×32×128)
      d3 = self.conv2d(d2, self.df * 4) # d2 to d3 (16×16×256)
      d4 = self.conv2d(d3, self.df * 8) # d3 to d4 (8×8×512)

      # Reduce d4 to a single-channel 8×8×1 map of per-patch validity scores
      validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)

      return Model(img, validity)
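
The discriminator output must match the `disc_patch` shape computed in `__init__`. A small sketch of that check, assuming the usual rule that a convolution with `'same'` padding produces `ceil(size / stride)` outputs per spatial dimension; `same_conv_output` is a hypothetical helper name.

```python
import math

def same_conv_output(size, stride):
    """Spatial output size of a convolution with 'same' padding."""
    return math.ceil(size / stride)

img_rows = 128
size = img_rows
for _ in range(4):                  # four stride-2 conv2d blocks
    size = same_conv_output(size, 2)
size = same_conv_output(size, 1)    # final stride-1 convolution keeps the size

patch = int(img_rows / 2**4)        # the formula used in __init__
print(size, patch)
```

Both values agree, so each of the 8×8 outputs scores one patch of the input image rather than the image as a whole.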

## The sampling function
Translate images in test (evaluation) mode and plot both cycles:
1. A -> B´ -> A´
2. B -> A´ -> B´


In [19]:
class CycleGAN(CycleGAN):
      def sample_images(self, epoch, batch_i):
        r, c = 2, 3

        imgs_A = self.data_loader.load_data(domain="A", batch_size=1, is_testing=True)
        imgs_B = self.data_loader.load_data(domain="B", batch_size=1, is_testing=True)
        
        # Translate images to the other domain
        fake_B = self.g_AB.predict(imgs_A)
        fake_A = self.g_BA.predict(imgs_B)
        # Translate back to original domain
        reconstr_A = self.g_BA.predict(fake_B)
        reconstr_B = self.g_AB.predict(fake_A)

        gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A, imgs_B, fake_A, reconstr_B])

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        titles = ['Original', 'Translated', 'Reconstructed']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt])
                axs[i, j].set_title(titles[j])
                axs[i,j].axis('off')
                cnt += 1
        plt.show()

## The training loop

In [20]:
class CycleGAN(CycleGAN):
      def train(self, epochs, batch_size=1, sample_interval=50):
        # Adversarial ground-truth labels (1 = real, 0 = fake), one per discriminator output patch
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)


        for epoch in range(epochs):
            for batch_i, (imgs_A, imgs_B) in enumerate(self.data_loader.load_batch(batch_size)):

                # ----------------------
                #  Train Discriminators
                # ----------------------

                # Translate images to opposite domain
                fake_B = self.g_AB.predict(imgs_A)
                fake_A = self.g_BA.predict(imgs_B)

                # Train the discriminators (original images = real / translated = Fake)
                dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
                dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
                dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

                dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
                dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
                dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

                # Total discriminator loss
                d_loss = 0.5 * np.add(dA_loss, dB_loss)

                # ------------------
                #  Train Generators
                # ------------------

                # Train the generators
                g_loss = self.combined.train_on_batch([imgs_A, imgs_B],
                                                      [valid, valid, imgs_A, imgs_B, imgs_A, imgs_B])
                # At every sample interval, plot the generated image samples
                if batch_i % sample_interval == 0:
                    self.sample_images(epoch, batch_i)
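
Two details of the loop are easy to miss: the target arrays carry one label per discriminator patch, and each discriminator loss is an element-wise average of the `[loss, accuracy]` pairs returned by `train_on_batch`. A minimal sketch with made-up numbers, using plain lists in place of NumPy arrays.

```python
batch_size = 2
disc_patch = (8, 8, 1)

# Shape of the `valid` / `fake` target arrays: one label per output patch
label_shape = (batch_size,) + disc_patch

# Made-up [loss, accuracy] pairs standing in for train_on_batch return values
d_loss_real = [0.30, 0.60]
d_loss_fake = [0.50, 0.40]

# Element-wise mean, equivalent to 0.5 * np.add(d_loss_real, d_loss_fake)
d_loss = [0.5 * (r + f) for r, f in zip(d_loss_real, d_loss_fake)]
print(label_shape, d_loss)
```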

## Train the CycleGAN
Instantiate the CycleGAN and run the training method.

In [21]:
cycle_gan = CycleGAN()
cycle_gan.train(epochs=100, batch_size=64, sample_interval=10)


# References
For more info, see the book *GANs in Action* by Jakub Langr and Vladimir Bok.