In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Load the apple2orange CycleGAN dataset through TensorFlow Datasets.
# as_supervised=True yields (image, label) tuples, which is the signature
# format_image() expects; with_info=True also returns the metadata object.
dataset, metadata = tfds.load(
    'cycle_gan/apple2orange',
    as_supervised=True,
    with_info=True,
)

In [None]:
# Display the loaded object; later cells index it by the split names
# 'trainA', 'trainB', 'testA' and 'testB'.
dataset

In [None]:
# Bind each split to its own name: the A splits go to *_monet and the
# B splits to *_photo.
# NOTE(review): the names say monet/photo but the dataset loaded above is
# apple2orange, so presumably A = apples and B = oranges -- confirm.
train_monet = dataset['trainA']
train_photo = dataset['trainB']
test_monet = dataset['testA']
test_photo = dataset['testB']

In [None]:
IMG_SIZE = 256


def format_image(image, label):
    """Convert an image to float32, rescale it and resize it to a square.

    The rescaling is ``x / 127.5 - 1``, which maps 0..255 onto -1..1.

    Args:
        image: Image tensor to preprocess.
        label: Passed through unchanged.

    Returns:
        A ``(image, label)`` tuple where the image is float32 with spatial
        size ``IMG_SIZE x IMG_SIZE``.
    """
    scaled = tf.cast(image, tf.float32) / 127.5 - 1
    resized = tf.image.resize(scaled, (IMG_SIZE, IMG_SIZE))
    return resized, label

In [None]:

BUFFER_SIZE = 1000
BATCH_SIZE = 1


def _prepare(split):
    """Preprocess, shuffle and batch one split with the constants above."""
    return split.map(format_image).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)


# Every split goes through the same map -> shuffle -> batch chain.
train_monet = _prepare(train_monet)
train_photo = _prepare(train_photo)
test_monet = _prepare(test_monet)
test_photo = _prepare(test_photo)

In [None]:
# Inspect the batched domain-A training dataset, which this notebook binds
# to the name `train_monet`.
train_monet

In [None]:
# Pair each domain-A batch with a domain-B batch for joint training.
# Both inputs are already shuffled and batched, so they are zipped as-is:
# batching again would add a second leading batch axis that the models
# (input_shape=(256, 256, 3)) cannot consume, and a second shuffle buffer
# would only hold extra image batches in memory.
# The labels are dropped so every element is (images_a, images_b), which is
# what train() reads as batch[0] and batch[1].
combined_dataset = tf.data.Dataset.zip((
    train_monet.map(lambda image, label: image),
    train_photo.map(lambda image, label: image),
))

In [None]:
# Take the first element of each training pipeline for quick inspection.
# The pipelines are mapped through format_image, which returns
# (image, label), so each sample is a batched (image, label) pair.
sample_apple, sample_orange = next(iter(train_monet)), next(iter(train_photo))

In [None]:
class ReflectionPad2d(tf.keras.layers.Layer):
    """Reflection-pad axes 1 and 2 of a rank-4 input.

    Args:
        padding: Number of elements added before and after axes 1 and 2.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (for example
            ``name`` or ``input_shape``).
    """

    def __init__(self, padding, **kwargs):
        super().__init__(**kwargs)
        # Keep the raw constructor argument so get_config() can rebuild
        # the layer when a model is saved or cloned.
        self.pad_amount = padding
        # tf.pad takes one [before, after] pair per axis; the first and
        # last axes are left unpadded.
        self.padding = [[0, 0], [padding, padding], [padding, padding], [0, 0]]

    def call(self, inputs, **kwargs):
        """Return ``inputs`` padded in 'REFLECT' mode."""
        return tf.pad(inputs, self.padding, 'REFLECT')

    def get_config(self):
        """Return the layer config including the ``padding`` argument."""
        config = super().get_config()
        config.update({'padding': self.pad_amount})
        return config


class ResNetBlock(tf.keras.Model):
    """Residual block: two reflection-padded 3x3 convolutions plus a skip.

    Each convolution is followed by batch normalization; a ReLU sits
    between the two stages. The block output is ``inputs + branch``, so
    ``dim`` should match the channel count of the input.

    Args:
        dim: Number of filters used by both convolutions.
        **kwargs: Forwarded to ``tf.keras.Model`` (for example ``name``).
    """

    def __init__(self, dim, **kwargs):
        super().__init__(**kwargs)
        # Padding by 1 before a 'valid' 3x3 conv keeps the spatial size.
        self.padding1 = ReflectionPad2d(1)
        self.conv1 = tf.keras.layers.Conv2D(dim, (3, 3), padding='valid', use_bias=False)
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.relu1 = tf.keras.layers.ReLU()

        self.padding2 = ReflectionPad2d(1)
        self.conv2 = tf.keras.layers.Conv2D(dim, (3, 3), padding='valid', use_bias=False)
        self.bn2 = tf.keras.layers.BatchNormalization()

    def call(self, inputs, training=None, mask=None):
        """Apply the residual branch and add it to ``inputs``.

        Args:
            inputs: Input tensor.
            training: Forwarded to the batch-normalization layers so they
                run in the mode the caller asked for.
            mask: Accepted for Keras API compatibility; not used.

        Returns:
            ``inputs`` plus the output of the two-convolution branch.
        """
        x = self.padding1(inputs)
        x = self.conv1(x)
        # Pass the flag explicitly instead of relying on implicit
        # propagation from the enclosing call context.
        x = self.bn1(x, training=training)
        x = self.relu1(x)
        x = self.padding2(x)
        x = self.conv2(x)
        x = self.bn2(x, training=training)
        return inputs + x


In [None]:
def make_generator_model(n_blocks):
    """Build the encoder / residual / decoder generator.

    Args:
        n_blocks: Number of ``ResNetBlock(256)`` layers placed between the
            encoder and the decoder.

    Returns:
        A ``tf.keras.Sequential`` model taking ``(256, 256, 3)`` inputs and
        ending in a 3-filter convolution with ``tanh`` activation.
    """
    layers = tf.keras.layers
    generator = tf.keras.Sequential()

    def add_norm_relu():
        # Every convolution except the last is followed by this pair.
        generator.add(layers.BatchNormalization())
        generator.add(layers.ReLU())

    # Encoding: reflection-padded 7x7 stem, then two stride-2 convolutions.
    generator.add(ReflectionPad2d(3, input_shape=(256, 256, 3)))
    generator.add(layers.Conv2D(64, (7, 7), strides=(1, 1), padding='valid', use_bias=False))
    add_norm_relu()

    for filters in (128, 256):
        generator.add(layers.Conv2D(filters, (3, 3), strides=(2, 2), padding='same', use_bias=False))
        add_norm_relu()

    # Transformation: stack of residual blocks at 256 channels.
    for _ in range(n_blocks):
        generator.add(ResNetBlock(256))

    # Decoding: two stride-2 transposed convolutions, then a 7x7 output conv.
    for filters in (128, 64):
        generator.add(layers.Conv2DTranspose(filters, (3, 3), strides=(2, 2), padding='same', use_bias=False))
        add_norm_relu()

    generator.add(ReflectionPad2d(3))
    generator.add(layers.Conv2D(3, (7, 7), strides=(1, 1), padding='valid', activation='tanh'))

    return generator


In [None]:
def make_discriminator_model():
    """Build the convolutional patch discriminator.

    Three stride-2 4x4 convolutions are followed by two stride-1 4x4
    convolutions, all with 'same' padding. The final layer has one filter
    and no activation, so the model emits a map of raw scores (logits).

    Returns:
        A ``tf.keras.Sequential`` model taking ``(256, 256, 3)`` inputs.
    """
    layers = tf.keras.layers
    model = tf.keras.Sequential()

    # First stage has no normalization. It uses 64 filters: a 3-filter
    # layer here would squeeze all later features through three channels,
    # breaking the 64 -> 128 -> 256 -> 512 PatchGAN progression.
    model.add(layers.Conv2D(64, (4, 4), strides=(2, 2), padding='same', input_shape=(256, 256, 3)))
    model.add(layers.LeakyReLU(alpha=0.2))

    # Remaining feature stages: conv -> batch norm -> leaky ReLU.
    for filters, stride in ((128, 2), (256, 2), (512, 1)):
        model.add(layers.Conv2D(filters, (4, 4), strides=(stride, stride), padding='same', use_bias=False))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU(alpha=0.2))

    # One score per output location; kept as logits for the loss.
    model.add(layers.Conv2D(1, (4, 4), strides=(1, 1), padding='same'))
    return model


In [None]:
# One generator per translation direction (nine residual blocks each) and
# one discriminator per image domain. Construction order is a2b, b2a, then
# discriminator b, then discriminator a.
generator_a2b, generator_b2a = make_generator_model(9), make_generator_model(9)
discriminator_b, discriminator_a = make_discriminator_model(), make_discriminator_model()

In [None]:
# Shared binary cross-entropy object used by the loss helpers below.
# from_logits=True: predictions are interpreted as raw, un-activated scores.
losses = tf.keras.losses.BinaryCrossentropy(
    from_logits=True,
)

In [None]:
def calc_gan_loss(prediction, is_real):
    """Adversarial loss of discriminator scores against a constant target.

    Args:
        prediction: Discriminator output, treated as raw logits because the
            shared ``losses`` object is built with ``from_logits=True``.
        is_real: Python bool. ``True`` scores against a target of ones,
            ``False`` against a target of zeros.

    Returns:
        The loss tensor produced by ``losses``.
    """
    # Keras loss objects are called as loss(y_true, y_pred): the constant
    # target goes first and the discriminator logits second.
    target = tf.ones_like(prediction) if is_real else tf.zeros_like(prediction)
    return losses(target, prediction)

def calc_cycle_loss(reconstructed_images, real_images):
    """Cycle-consistency loss: mean absolute (L1) pixel difference.

    Binary cross-entropy on logits is not meaningful for image tensors
    scaled to [-1, 1], so the reconstruction is compared with the original
    directly.

    Args:
        reconstructed_images: Images after a full A->B->A or B->A->B pass.
        real_images: The images the cycle started from.

    Returns:
        Scalar tensor holding the mean absolute difference.
    """
    return tf.reduce_mean(tf.abs(real_images - reconstructed_images))

def calc_identity_loss(identity_images, real_images):
    """Identity loss: mean absolute (L1) pixel difference.

    Measures how much a generator changes an image that already belongs to
    its target domain. Binary cross-entropy on logits is not meaningful for
    image tensors scaled to [-1, 1], so an L1 distance is used.

    Args:
        identity_images: Generator output for an input from its own target
            domain.
        real_images: That same input.

    Returns:
        Scalar tensor holding the mean absolute difference.
    """
    return tf.reduce_mean(tf.abs(real_images - identity_images))


In [None]:
@tf.function
def train_generator(images_a, images_b):
    """Run one optimization step on both generators.

    Args:
        images_a: Batch of real domain-A images.
        images_b: Batch of real domain-B images.

    Returns:
        A ``(fake_a2b, fake_b2a, losses)`` tuple: the two translated
        batches (for the discriminator step) and a dict of loss tensors.
    """
    real_a = images_a
    real_b = images_b
    with tf.GradientTape() as tape:
        # Translate each real batch into the other domain.
        fake_a2b = generator_a2b(real_a, training=True)
        fake_b2a = generator_b2a(real_b, training=True)

        # Translate back to close the cycle: A->B->A and B->A->B.
        recon_b2a = generator_b2a(fake_a2b, training=True)
        recon_a2b = generator_a2b(fake_b2a, training=True)

        # Use real B to generate B should be identical
        identity_a2b = generator_a2b(real_b, training=True)
        identity_b2a = generator_b2a(real_a, training=True)
        loss_identity_a2b = calc_identity_loss(identity_a2b, real_b)
        loss_identity_b2a = calc_identity_loss(identity_b2a, real_a)

        # Generator A2B tries to trick Discriminator B that the generated image is B
        loss_gan_gen_a2b = calc_gan_loss(discriminator_b(fake_a2b, training=True), True)
        # Generator B2A tries to trick Discriminator A that the generated image is A
        loss_gan_gen_b2a = calc_gan_loss(discriminator_a(fake_b2a, training=True), True)
        loss_cycle_a2b2a = calc_cycle_loss(recon_b2a, real_a)
        loss_cycle_b2a2b = calc_cycle_loss(recon_a2b, real_b)

        # Total generator loss: adversarial + 10 * cycle + 5 * identity.
        loss_gen_total = loss_gan_gen_a2b + loss_gan_gen_b2a \
            + (loss_cycle_a2b2a + loss_cycle_b2a2b) * 10 \
            + (loss_identity_a2b + loss_identity_b2a) * 5

    # Only the generators are updated here; the discriminators just score.
    trainable_variables = generator_a2b.trainable_variables + generator_b2a.trainable_variables
    gradient_gen = tape.gradient(loss_gen_total, trainable_variables)
    # NOTE(review): optimizer_gen is a module-level optimizer that is not
    # defined in the visible cells; it must exist before the first call.
    optimizer_gen.apply_gradients(zip(gradient_gen, trainable_variables))

    return fake_a2b, fake_b2a, {
        'loss_gan_gen_a2b': loss_gan_gen_a2b,
        'loss_gan_gen_b2a': loss_gan_gen_b2a,
        'loss_cycle_a2b2a': loss_cycle_a2b2a,
        'loss_cycle_b2a2b': loss_cycle_b2a2b,
        'loss_identity_a2b': loss_identity_a2b,
        'loss_identity_b2a': loss_identity_b2a,
        'loss_gen_total': loss_gen_total,
    }


In [None]:
@tf.function
def train_discriminator(images_a, images_b, fake_a2b, fake_b2a):
    """Run one optimization step on both discriminators.

    Args:
        images_a: Batch of real domain-A images.
        images_b: Batch of real domain-B images.
        fake_a2b: Batch of generated domain-B images.
        fake_b2a: Batch of generated domain-A images.

    Returns:
        A dict of loss tensors: the per-discriminator losses and their sum.
    """
    real_a = images_a
    real_b = images_b
    with tf.GradientTape() as tape:

        # Discriminator A should classify real_a as A
        loss_gan_dis_a_real = calc_gan_loss(discriminator_a(real_a, training=True), True)
        # Discriminator A should classify generated fake_b2a as not A
        loss_gan_dis_a_fake = calc_gan_loss(discriminator_a(fake_b2a, training=True), False)

        # Discriminator B should classify real_b as B
        loss_gan_dis_b_real = calc_gan_loss(discriminator_b(real_b, training=True), True)
        # Discriminator B should classify generated fake_a2b as not B
        loss_gan_dis_b_fake = calc_gan_loss(discriminator_b(fake_a2b, training=True), False)

        # Total discriminator loss: real/fake terms averaged per
        # discriminator, then summed over the two discriminators.
        loss_dis_a = (loss_gan_dis_a_real + loss_gan_dis_a_fake) * 0.5
        loss_dis_b = (loss_gan_dis_b_real + loss_gan_dis_b_fake) * 0.5
        loss_dis_total = loss_dis_a + loss_dis_b

    trainable_variables = discriminator_a.trainable_variables + discriminator_b.trainable_variables
    gradient_dis = tape.gradient(loss_dis_total, trainable_variables)
    # NOTE(review): optimizer_dis is a module-level optimizer that is not
    # defined in the visible cells; it must exist before the first call.
    optimizer_dis.apply_gradients(zip(gradient_dis, trainable_variables))

    # Hand the losses back so the caller can log them.
    return {
        'loss_dis_a': loss_dis_a,
        'loss_dis_b': loss_dis_b,
        'loss_dis_total': loss_dis_total,
    }


In [None]:
def train_step(images_a, images_b, epoch, step):
    """Run one generator update followed by one discriminator update.

    Args:
        images_a: Batch of real domain-A images.
        images_b: Batch of real domain-B images.
        epoch: Current epoch index; accepted for the caller, not used here.
        step: Current step index; accepted for the caller, not used here.

    Returns:
        A ``(gen_loss_dict, dis_loss_dict)`` tuple holding whatever the
        generator and discriminator steps reported.
    """
    fake_a2b, fake_b2a, gen_loss_dict = train_generator(images_a, images_b)

    # The discriminators are trained on images obtained from the pools
    # rather than on the fresh generator outputs directly.
    # NOTE(review): fake_pool_b2a / fake_pool_a2b are module-level objects
    # not defined in the visible cells; only `.query(images)` is used here.
    fake_b2a_from_pool = fake_pool_b2a.query(fake_b2a)
    fake_a2b_from_pool = fake_pool_a2b.query(fake_a2b)

    dis_loss_dict = train_discriminator(images_a, images_b, fake_a2b_from_pool, fake_b2a_from_pool)

    # Return the losses instead of discarding them so callers can log them.
    return gen_loss_dict, dis_loss_dict

def train(dataset, epochs):
    """Iterate over ``dataset`` ``epochs`` times, training on every batch.

    Args:
        dataset: Iterable of batches; items 0 and 1 of each batch are passed
            to ``train_step`` as the domain-A and domain-B images.
        epochs: Number of full passes over ``dataset``.
    """
    for epoch in range(epochs):
        # The step counter restarts at zero for every epoch.
        step = 0
        for batch in dataset:
            images_a = batch[0]
            images_b = batch[1]
            train_step(images_a, images_b, epoch, step)
            step += 1
