# CycleGAN painter!

This notebook is inspired by Amy Jang's tutorial notebook.


# Introduction and Setup

This notebook utilizes a CycleGAN architecture to add Monet-style to photos. For this tutorial, we will be using the TFRecord dataset. Import the following packages and change the accelerator to TPU.

For more information, check out [TensorFlow](https://www.tensorflow.org/tutorials/generative/cyclegan) and [Keras](https://keras.io/examples/generative/cyclegan/) CycleGAN documentation pages.

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa # addons ???
import math,random,cv2
from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as plt
import numpy as np
import PIL

# Prefer a TPU when one can be resolved; otherwise fall back to the default
# distribution strategy so the notebook still runs.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as err:
    # `Exception` (not a bare except) keeps Ctrl-C / SystemExit working, and
    # printing the reason makes an unexpected fallback visible.
    print('TPU not available, using the default strategy:', err)
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

# Lets tf.data choose the parallelism / prefetch buffer sizes.
AUTOTUNE = tf.data.experimental.AUTOTUNE

print(tf.__version__)

# Load in the data

We want to keep our photo dataset and our Monet dataset separate. First, load in the filenames of the TFRecords.

In [None]:
# Google Cloud Storage path of the attached Kaggle dataset; it acts as the
# root directory from which the TFRecord files are globbed.
GCS_PATH = KaggleDatasets().get_gcs_path()

In [None]:
# TFRecord files that hold the Monet paintings.
MONET_FILENAMES = tf.io.gfile.glob(str(GCS_PATH + '/monet_tfrec/*.tfrec'))
print('Monet TFRecord Files:', len(MONET_FILENAMES))

# TFRecord files that hold the photos we want to restyle like Monet paintings.
PHOTO_FILENAMES = tf.io.gfile.glob(str(GCS_PATH + '/photo_tfrec/*.tfrec'))
print('Photo TFRecord Files:', len(PHOTO_FILENAMES))

# Number of samples
https://www.kaggle.com/swepat/cyclegan-to-generate-monet-style-images

In [None]:
import re
def count_data_items(filenames):
    """Return the total number of records encoded in TFRecord file names.

    Every file name is expected to contain ``-<count>.`` (a name such as
    ``photo12-352.tfrec`` contributes 352); the per-file counts are summed.

    Args:
        filenames: Iterable of file paths (local or ``gs://``).

    Returns:
        int: Sum of the per-file counts, ``0`` for an empty iterable.

    Raises:
        ValueError: If a file name does not contain a ``-<count>.`` marker.
    """
    # Compiled once instead of once per file name.
    count_pattern = re.compile(r'-([0-9]+)\.')

    total = 0
    for filename in filenames:
        # Look at the base name only, so digits in a bucket or directory name
        # (e.g. "gs://kds-12.x/...") cannot be mistaken for the count.
        basename = str(filename).rsplit('/', 1)[-1]
        match = count_pattern.search(basename)
        if match is None:
            raise ValueError(
                f'cannot read the record count from file name {filename!r}')
        total += int(match.group(1))
    return total

In [None]:
# Number of images per domain, read from the counts in the TFRecord file names.
n_monet_samples = count_data_items(MONET_FILENAMES)
n_photo_samples = count_data_items(PHOTO_FILENAMES)
# Number of TFRecord files (not images) per domain.
print('number of monet tfrecord files :', len(MONET_FILENAMES))
print('number of photo tfrecord files:', len(PHOTO_FILENAMES))

In [None]:
# Report how many individual images each domain contains.
print('n_monet_samples : ',n_monet_samples)
print('n_photo_samples : ',n_photo_samples)

# Visualize the JPEG images


In [None]:
import glob
import os

# To find out which files really exist under ../input, walk the tree step by
# step with os.listdir; the file browser in the sidebar is not always
# accurate (for example, there may be no "gan-getting-started" folder at all).
BASE_PATH = '../input/monet-gan-getting-started/'
monet_path, photo_path = (
    glob.glob(BASE_PATH + pattern)
    for pattern in ('monet_jpg/*jpg', 'photo_jpg/*jpg')
)

# Debugging aids for checking the directory layout:
# print(os.getcwd())
# print(os.listdir('../input/monet-gan-getting-started/monet_jpg'))



def visualization_images(imagepath, n_images, is_random=True, figsize=(16, 16)):
    """Show up to ``n_images`` image files in a near-square grid.

    Args:
        imagepath: Sequence of image file paths.
        n_images: Maximum number of images to draw; clamped to the number of
            available paths.
        is_random: If true, draw a random sample from all of ``imagepath``;
            otherwise use the first ``n_images`` paths in order.
        figsize: Matplotlib figure size in inches.
    """
    # Clamp so that a short path list cannot make random.sample raise.
    n_images = min(n_images, len(imagepath))
    if n_images <= 0:
        # Nothing to draw; this also avoids dividing by zero rows below.
        return

    if is_random:
        # Sample from the whole list, not just a shuffle of the first n paths.
        image_names = random.sample(list(imagepath), n_images)
    else:
        image_names = list(imagepath[:n_images])

    rows = int(n_images ** 0.5)
    cols = math.ceil(n_images / rows)

    plt.figure(figsize=figsize)
    for i, image_name in enumerate(image_names):
        img = cv2.imread(image_name)
        if img is None:
            # cv2.imread returns None instead of raising for unreadable files;
            # leave that grid cell empty.
            continue
        # OpenCV loads images as BGR; matplotlib expects RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        plt.subplot(rows, cols, i + 1)  # the grid has `rows` rows and `cols` columns
        plt.imshow(img)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

# Preview 12 of the Monet JPEG files.
visualization_images(monet_path, 12, is_random=True, figsize=(23, 23))

In [None]:
# Preview 12 of the photo JPEG files.
visualization_images(photo_path, 12, is_random=True, figsize=(23, 23))

All the images for the competition are already sized to 256x256. As these images are RGB images, set the channel to 3. Additionally, we need to scale the images to a [-1, 1] scale. Because we are building a generative model, we don't need the labels or the image id so we'll only return the image from the TFRecord.

In [None]:
IMAGE_SIZE = [256, 256]

def decode_image(image):
    """Decode JPEG bytes into a float32 RGB image scaled to [-1, 1].

    The result is reshaped to ``IMAGE_SIZE`` x 3 so that it carries a fixed
    static shape.
    """
    pixels = tf.cast(tf.image.decode_jpeg(image, channels=3), tf.float32)
    # Map the byte range 0..255 onto -1..1.
    normalized = (pixels / 127.5) - 1
    return tf.reshape(normalized, (*IMAGE_SIZE, 3))

def read_tfrecord(example):
    """Parse one serialized example and return only its decoded image.

    Only the ``image`` feature is requested: the image name and target are
    never used for GAN training, and requiring them would make parsing fail
    on records that do not carry them.

    Args:
        example: Scalar string tensor holding a serialized ``tf.train.Example``.

    Returns:
        The image as produced by ``decode_image``.
    """
    tfrecord_format = {
        "image": tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(example, tfrecord_format)
    return decode_image(example['image'])

Define the function to extract the image from the files.

In [None]:
def load_dataset(filenames, labeled=True, ordered=False):
    """Build a dataset of decoded images from TFRecord files.

    Args:
        filenames: TFRecord file paths to read.
        labeled: Accepted for call compatibility; currently unused.
        ordered: Accepted for call compatibility; currently unused.

    Returns:
        An unbatched ``tf.data.Dataset`` of images parsed by ``read_tfrecord``.
    """
    records = tf.data.TFRecordDataset(filenames)
    return records.map(
        read_tfrecord,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )

Let's load in our datasets.

In [None]:
# Un-augmented datasets in batches of one image, used for previews and for
# inference. (`labeled` is accepted but ignored by load_dataset.)
monet_ds = load_dataset(MONET_FILENAMES, labeled=True).batch(1)
photo_ds = load_dataset(PHOTO_FILENAMES, labeled=True).batch(1)

In [None]:
# Take the first batch of each dataset as a fixed example for later cells.
example_monet = next(iter(monet_ds))
example_photo = next(iter(photo_ds))

Let's  visualize a photo example and a Monet example.

In [None]:
# Show the example photo and the example Monet side by side. Pixel values
# are mapped from [-1, 1] back to [0, 1] for display.
for _slot, (_caption, _sample) in enumerate(
        (('Photo', example_photo), ('Monet', example_monet)), start=1):
    plt.subplot(1, 2, _slot)
    plt.title(_caption)
    plt.imshow(_sample[0] * 0.5 + 0.5)

# Data augmentation

In [None]:
import random
import tensorflow as tf
import math
import tensorflow.keras.backend as K

def transform_rotation(image,height,rotation):
    """Rotate a square image by a random angle.

    The angle is ``rotation`` multiplied by a uniform sample from [0, 1), so
    it lies between 0 and ``rotation`` degrees. Each destination pixel is
    mapped back onto a source pixel through the rotation matrix and filled
    with that pixel's value.

    Args:
        image: Image tensor; assumed to be (height, height, 3) given the final
            reshape - TODO confirm with the callers.
        height: Side length of the square image in pixels.
        rotation: Upper bound of the rotation angle in degrees.

    Returns:
        Tensor of shape (height, height, 3).
    """
    dim = height
    xdim = height % 2  # 1 for odd sizes; shifts the lower clip bound below
    rotation = rotation * tf.random.uniform([1], dtype='float32')
    rotation = math.pi / 180. * rotation  # degrees -> radians
    # rotation matrix
    c1 = tf.math.cos(rotation)
    s1 = tf.math.sin(rotation)
    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')
    rotation_matrix = tf.reshape(tf.concat([c1,s1,zero,-s1,c1,zero,zero,zero,one],axis=0),shape=[3,3])

    # destination pixel indices, centred on the middle of the image:
    x = tf.repeat(tf.range(dim//2,-dim//2,-1), dim) # shape like : [111222333]
    y = tf.tile(tf.range(-dim//2,dim//2), [dim]) #shape like: [123123123]
    z = tf.ones([dim*dim], dtype='int32')
    idx = tf.stack([x,y,z]) # shape : 3 x (dim*dim)

    # rotate destination pixels onto origin pixels :
    idx2 = K.dot(rotation_matrix, tf.cast(idx, dtype='float32'))
    idx2 = K.cast(idx2, dtype='int32')
    # clamp so that every source coordinate stays inside the image
    idx2 = K.clip(idx2,-dim//2+1+xdim, dim//2)

    # find original pixel values:
    # convert the centred coordinates back to image (row, column) indices
    idx3 = tf.stack([dim//2-idx2[0,], dim//2-1+idx2[1,]])

    d = tf.gather_nd(image, tf.transpose(idx3))
    return tf.reshape(d, shape=[dim,dim,3])


def data_augument(image):
    """Randomly zoom, rotate and flip one 256x256x3 image.

    Meant to be passed to ``tf.data.Dataset.map``. All randomness therefore
    comes from TensorFlow ops: plain Python ``random`` calls would run only
    once, when the function is traced, and the same value would then be used
    for every image.

    Args:
        image: Image tensor of shape (256, 256, 3).

    Returns:
        The augmented image, again of shape (256, 256, 3).
    """
    # Probability draw that selects the rotation.
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    # Probability draw that selects the flips / transpose.
    p_spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    # Probability draw that selects the random zoom-and-crop.
    p_crop = tf.random.uniform([], 0, 1.0, dtype=tf.float32)

    # Random zoom: enlarge the image, then crop back to 256x256.
    if p_crop > 0.5:
        # maxval is exclusive, so 291 keeps the original inclusive 260..290.
        temp_size = tf.random.uniform([], 260, 291, dtype=tf.int32)
        image = tf.image.resize(image, tf.stack([temp_size, temp_size]))
        image = tf.image.random_crop(image, size=[256, 256, 3])
        if p_crop > 0.9:
            # Occasionally zoom a second time (inclusive 290..310).
            temp_size2 = tf.random.uniform([], 290, 311, dtype=tf.int32)
            image = tf.image.resize(image, tf.stack([temp_size2, temp_size2]))
            image = tf.image.random_crop(image, size=[256, 256, 3])

    # Random rotation by a multiple of 90 degrees.
    if p_rotate > 0.75:
        image = tf.image.rot90(image, k=3)
    elif p_rotate > .5:
        image = tf.image.rot90(image, k=2)
    elif p_rotate > 0.25:
        image = tf.image.rot90(image, k=1)

    # Additional free rotation of up to 45 degrees.
    if p_rotate >= 0.3:
        image = transform_rotation(image, height=256, rotation=45.)

    # Random flips.
    if p_spatial > 0.6:
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        if p_spatial > 0.85:
            # tf.image.transpose swaps the height and width axes.
            image = tf.image.transpose(image)

    return image




### Training dataset

In [None]:
AUTO = tf.data.experimental.AUTOTUNE
def get_gan_dataset(monet_files, photo_files, augment=None,repeat=True,shuffle=True,batch_size=1, cache=True):
    """Build the zipped (monet_batch, photo_batch) training dataset.

    Args:
        monet_files: TFRecord files with Monet paintings.
        photo_files: TFRecord files with photos.
        augment: Optional per-image function mapped over both datasets.
        repeat: Repeat both datasets indefinitely.
        shuffle: Shuffle both datasets with a 2048-element buffer.
        batch_size: Images per batch; incomplete batches are dropped.
        cache: Cache the decoded images in memory. Disable this if the
            decoded dataset does not fit into RAM.

    Returns:
        ``tf.data.Dataset`` yielding ``(monet_batch, photo_batch)`` tuples.
    """
    def _prepare(files):
        """Apply the shared cache/augment/repeat/shuffle/batch pipeline."""
        ds = load_dataset(files)
        if cache:
            # Cache the finite, decoded data *before* augment/repeat/shuffle.
            # Caching after repeat() would try to store an endless stream, so
            # the cache would never complete and memory would keep growing.
            ds = ds.cache()
        if augment:
            ds = ds.map(augment, num_parallel_calls=AUTO)
        if repeat:
            ds = ds.repeat()
        if shuffle:
            ds = ds.shuffle(2048)
        ds = ds.batch(batch_size, drop_remainder=True)
        return ds.prefetch(AUTO)

    monet_ds = _prepare(monet_files)
    photo_ds = _prepare(photo_files)

    gan_ds = tf.data.Dataset.zip((monet_ds,photo_ds))

    return gan_ds




In [None]:
# Augmented, shuffled, endlessly repeating pairs of (Monet, photo) batches of size 1, used for the previews below.
gan_dataset = get_gan_dataset(MONET_FILENAMES,PHOTO_FILENAMES,augment=data_augument,repeat=True,shuffle=True,batch_size=1)

In [None]:
# example_monet,example_photo = next(iter(gan_dataset)) 


# Show augmented images

In [None]:
def view_images(example, nrows=1,ncols=5):
    """Plot the first image of the next ``nrows * ncols`` batches.

    Args:
        example: Iterable of image batches exposing ``.numpy()``.
        nrows: Number of grid rows.
        ncols: Number of grid columns.
    """
    batches = iter(example)
    figure = plt.figure(figsize=(25, nrows*5.05))
    for slot in range(1, ncols*nrows + 1):
        batch = next(batches).numpy()
        axes = figure.add_subplot(nrows, ncols, slot, xticks=[], yticks=[])
        # Map pixel values from [-1, 1] back to [0, 1] for display.
        axes.imshow(batch[0]*0.5+.5)
        


In [None]:
# First five un-augmented Monet paintings.
view_images(monet_ds)

In [None]:
# First five un-augmented photos.
view_images(photo_ds)

In [None]:
def view_images_2(aug_example,nrows=2,ncols=5):
    """Plot (Monet, photo) pairs: Monet images above the matching photos.

    The grid is filled in bands of two rows, with a Monet row on top and a
    photo row below, so ``nrows`` should be even. ``ncols * (nrows // 2)``
    pairs are drawn.

    Args:
        aug_example: Iterable yielding ``(monet_batch, photo_batch)`` tuples
            whose elements expose ``.numpy()``.
        nrows: Number of grid rows.
        ncols: Number of grid columns (pairs per band).
    """
    ds_iter = iter(aug_example)
    fig = plt.figure(figsize=(25,nrows*5.05))
    for i in range(ncols * (nrows // 2)):
        monetimage, photoimage = next(ds_iter)
        monetimage = monetimage.numpy()
        photoimage = photoimage.numpy()
        # Work out the band and column of this pair. Using `i + 1` and
        # `i + ncols + 1` directly is only right for a single band and makes
        # Monet and photo rows overlap as soon as nrows > 2.
        band, col = divmod(i, ncols)
        monet_slot = 2 * band * ncols + col + 1
        photo_slot = monet_slot + ncols
        ax = fig.add_subplot(nrows, ncols, monet_slot, xticks=[], yticks=[])
        ax.imshow(monetimage[0]*0.5+0.5)
        ax2 = fig.add_subplot(nrows, ncols, photo_slot, xticks=[], yticks=[])
        ax2.imshow(photoimage[0]*0.5+0.5)
        

In [None]:
# Augmented samples: Monet paintings in the top row, photos in the bottom row.
view_images_2(gan_dataset)

# Build the generator

We'll be using a UNET architecture for our CycleGAN. To build our generator, let's first define our `downsample` and `upsample` methods.

The `downsample`, as the name suggests, reduces the 2D dimensions, the width and height, of the image by the stride. The stride is the length of the step the filter takes. Since the stride is 2, the filter is applied to every other pixel, hence reducing the width and height by 2.

We'll be using an instance normalization instead of batch normalization. As the instance normalization is not standard in the TensorFlow API, we'll use the layer from TensorFlow Add-ons.

In [None]:
OUTPUT_CHANNELS = 3

def downsample(filters, size, apply_instancenorm=True):
    """Return a stride-2 convolution block that halves height and width.

    The block is Conv2D -> (optional) InstanceNormalization -> LeakyReLU.

    Args:
        filters: Number of convolution filters.
        size: Convolution kernel size.
        apply_instancenorm: Whether to insert the instance-normalization layer.

    Returns:
        A ``keras.Sequential`` model.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    norm_gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    stages = [
        layers.Conv2D(filters, size, strides=2, padding='same',
                      kernel_initializer=kernel_init, use_bias=False),
    ]
    if apply_instancenorm:
        stages.append(
            tfa.layers.InstanceNormalization(gamma_initializer=norm_gamma_init))
    stages.append(layers.LeakyReLU())

    return keras.Sequential(stages)

`Upsample` does the opposite of downsample and increases the dimensions of the image. `Conv2DTranspose` does basically the opposite of a `Conv2D` layer.

In [None]:
def upsample(filters, size, apply_dropout=False):
    """Return a stride-2 transposed-convolution block that doubles height and width.

    The block is Conv2DTranspose -> InstanceNormalization -> (optional)
    Dropout(0.5) -> ReLU.

    Args:
        filters: Number of transposed-convolution filters.
        size: Kernel size.
        apply_dropout: Whether to insert the dropout layer.

    Returns:
        A ``keras.Sequential`` model.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    norm_gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    stages = [
        layers.Conv2DTranspose(filters, size, strides=2,
                               padding='same',
                               kernel_initializer=kernel_init,
                               use_bias=False),
        tfa.layers.InstanceNormalization(gamma_initializer=norm_gamma_init),
    ]
    if apply_dropout:
        stages.append(layers.Dropout(0.5))
    stages.append(layers.ReLU())

    return keras.Sequential(stages)

Let's build our generator!

The generator first downsamples the input image and then upsample while establishing long skip connections. Skip connections are a way to help bypass the vanishing gradient problem by concatenating the output of a layer to multiple layers instead of only one. Here we concatenate the output of the downsample layer to the upsample layer in a symmetrical fashion.
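In short: the generator connects each downsampling layer to its matching upsampling layer, which helps prevent vanishing gradients.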

In [None]:
def Generator():
    """Build the U-Net generator mapping a 256x256x3 image to a 256x256x3 image.

    Eight downsampling blocks encode the input; seven upsampling blocks decode
    it, each concatenated with the encoder output of the same resolution (skip
    connection). A final tanh transposed convolution restores full resolution.
    """
    inputs = layers.Input(shape=[256,256,3])

    # bs = batch size; the shape comments give each encoder block's output.
    down_stack = [downsample(64, 4, apply_instancenorm=False)]  # (bs, 128, 128, 64)
    down_stack += [
        downsample(depth, 4)
        for depth in (
            128,  # (bs, 64, 64, 128)
            256,  # (bs, 32, 32, 256)
            512,  # (bs, 16, 16, 512)
            512,  # (bs, 8, 8, 512)
            512,  # (bs, 4, 4, 512)
            512,  # (bs, 2, 2, 512)
            512,  # (bs, 1, 1, 512)
        )
    ]

    # Shapes after concatenation with the skip connection: the block outputs
    # 512 channels and the skip adds another 512, giving 1024.
    up_stack = [
        upsample(512, 4, apply_dropout=True)  # (bs, 2, 2, 1024), (bs, 4, 4, 1024), (bs, 8, 8, 1024)
        for _ in range(3)
    ]
    up_stack += [
        upsample(depth, 4)
        for depth in (
            512,  # (bs, 16, 16, 1024)
            256,  # (bs, 32, 32, 512)
            128,  # (bs, 64, 64, 256)
            64,   # (bs, 128, 128, 128)
        )
    ]

    initializer = tf.random_normal_initializer(0., 0.02)
    last = layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                  strides=2,
                                  padding='same',
                                  kernel_initializer=initializer,
                                  activation='tanh') # (bs, 256, 256, 3)

    # Downsampling through the model, keeping every intermediate output so
    # it can be concatenated back in later.
    x = inputs
    encoded = []
    for down in down_stack:
        x = down(x)
        encoded.append(x)

    # Upsampling and establishing the skip connections. The bottleneck (last
    # encoder output) is not a skip; the others are used deepest-first.
    for up, skip in zip(up_stack, encoded[-2::-1]):
        x = up(x)
        x = layers.Concatenate()([x, skip])

    x = last(x)

    return keras.Model(inputs=inputs, outputs=x)

# Build the discriminator

The discriminator takes in the input image and classifies it as real or fake (generated). Instead of outputting a single node, the discriminator outputs a smaller 2D image with higher pixel values indicating a real classification and lower values indicating a fake classification.

In [None]:
def Discriminator():
    """Build the patch discriminator: 256x256x3 image -> 30x30x1 score map.

    Higher output values mean "real", lower values mean "fake"; the outputs
    are raw scores with no final activation.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    inp = layers.Input(shape=[256, 256, 3], name='input_image')

    # Three downsampling blocks:
    # (bs, 128, 128, 64) -> (bs, 64, 64, 128) -> (bs, 32, 32, 256)
    features = inp
    for depth, use_norm in ((64, False), (128, True), (256, True)):
        features = downsample(depth, 4, use_norm)(features)

    features = layers.ZeroPadding2D()(features) # (bs, 34, 34, 256)
    features = layers.Conv2D(512, 4, strides=1,
                             kernel_initializer=initializer,
                             use_bias=False)(features) # (bs, 31, 31, 512)

    # Instance normalization normalizes every channel of every sample on its
    # own, whereas batch normalization normalizes each channel across the
    # whole batch; gamma_init only sets how the scale parameter is initialised.
    features = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(features)
    features = layers.LeakyReLU()(features)

    features = layers.ZeroPadding2D()(features) # (bs, 33, 33, 512)

    last = layers.Conv2D(1, 4, strides=1,
                         kernel_initializer=initializer)(features) # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=inp, outputs=last)

In [None]:
# Create the four networks inside the strategy scope so that their variables
# are created under the chosen distribution strategy.
with strategy.scope():
    monet_generator = Generator() # transforms photos to Monet-esque paintings
    photo_generator = Generator() # transforms Monet paintings to be more like photos

    monet_discriminator = Discriminator() # differentiates real Monet paintings and generated Monet paintings
    photo_discriminator = Discriminator() # differentiates real photos and generated photos

Since our generators are not trained yet, the generated Monet-esque photo does not show what is expected at this point.

In [None]:
# Run the (still untrained) Monet generator on the example photo.
to_monet = monet_generator(example_photo)

# Left: the input photo. Pixel values are mapped from [-1, 1] to [0, 1].
plt.subplot(1, 2, 1)
plt.title("Original Photo")
plt.imshow(example_photo[0] * 0.5 + 0.5)

# Right: the generator output for that photo.
plt.subplot(1, 2, 2)
plt.title("Monet-esque Photo")
plt.imshow(to_monet[0] * 0.5 + 0.5)
plt.show()

# Build the CycleGAN model

We will subclass a `tf.keras.Model` so that we can run `fit()` later to train our model. During the training step, the model transforms a photo to a Monet painting and then back to a photo. The difference between the original photo and the twice-transformed photo is the cycle-consistency loss. We want the original photo and the twice-transformed photo to be similar to one another.

The losses are defined in the next section.

In [None]:
class CycleGan(keras.Model):
    """Keras model that trains two generators and two discriminators together.

    ``m_gen`` turns photos into Monet-style images and ``p_gen`` turns Monet
    paintings into photo-style images; ``m_disc`` and ``p_disc`` score real
    versus generated images of the respective domain. ``train_step`` performs
    one optimization step on all four networks.
    """

    def __init__(
        self,
        monet_generator,
        photo_generator,
        monet_discriminator,
        photo_discriminator,
        lambda_cycle=10,
    ):
        """Store the four sub-networks and the weight passed to the cycle/identity losses."""
        super(CycleGan, self).__init__()
        self.m_gen = monet_generator
        self.p_gen = photo_generator
        self.m_disc = monet_discriminator
        self.p_disc = photo_discriminator
        self.lambda_cycle = lambda_cycle

    def compile(
        # Receives one optimizer per sub-network plus the four loss functions.
        self,
        m_gen_optimizer,
        p_gen_optimizer,
        m_disc_optimizer,
        p_disc_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
        identity_loss_fn
    ):
        """Store the optimizers and loss functions used by ``train_step``."""
        super(CycleGan, self).compile()
        self.m_gen_optimizer = m_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.m_disc_optimizer = m_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer
        # The optimizers apply gradients, and the gradients are derived from
        # the losses computed by these functions.
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn

    def train_step(self, batch_data):
        """Run one training step on a ``(real_monet, real_photo)`` batch pair.

        Returns:
            dict with the total Monet/photo generator losses and the
            Monet/photo discriminator losses, in whatever shape the loss
            functions produce.
        """
        real_monet, real_photo = batch_data

        # persistent=True because four separate gradients are taken from this tape.
        with tf.GradientTape(persistent=True) as tape:
            # photo to monet back to photo
            fake_monet = self.m_gen(real_photo, training=True)
            cycled_photo = self.p_gen(fake_monet, training=True)

            # monet to photo back to monet
            fake_photo = self.p_gen(real_monet, training=True)
            cycled_monet = self.m_gen(fake_photo, training=True)

            # generating itself
            same_monet = self.m_gen(real_monet, training=True)
            same_photo = self.p_gen(real_photo, training=True)

            # Next, let the discriminators score the real and the generated images.
            # discriminator used to check, inputing real images
            disc_real_monet = self.m_disc(real_monet, training=True)
            disc_real_photo = self.p_disc(real_photo, training=True)

            # discriminator used to check, inputing fake images
            disc_fake_monet = self.m_disc(fake_monet, training=True)
            disc_fake_photo = self.p_disc(fake_photo, training=True)

            # Three kinds of loss: generator loss, cycle loss and discriminator loss.
            # Each generator's total loss has three parts: (1) the adversarial loss
            # on the discriminator's score of its fake image, (2) the cycle loss
            # between the real images and their twice-transformed versions, and
            # (3) the identity loss between an image of the target domain and the
            # generator's output for it.
            # evaluates generator loss
            monet_gen_loss = self.gen_loss_fn(disc_fake_monet)
            photo_gen_loss = self.gen_loss_fn(disc_fake_photo)

            # evaluates total cycle consistency loss
            total_cycle_loss = self.cycle_loss_fn(real_monet, cycled_monet, self.lambda_cycle) + self.cycle_loss_fn(real_photo, cycled_photo, self.lambda_cycle)

            # This yields four losses (Monet generator, photo generator, Monet
            # discriminator, photo discriminator), one per sub-network; each is
            # used to optimize its own network.
            # evaluates total generator loss
            total_monet_gen_loss = monet_gen_loss + total_cycle_loss + self.identity_loss_fn(real_monet, same_monet, self.lambda_cycle)
            total_photo_gen_loss = photo_gen_loss + total_cycle_loss + self.identity_loss_fn(real_photo, same_photo, self.lambda_cycle)

            # Discriminator losses.
            # evaluates discriminator loss
            monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet) # real Monet vs. Monet generated from a photo
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo) # real photo vs. photo generated from a Monet

        # Compute the gradients, then apply them to update the networks.
        # Calculate the gradients for generator and discriminator
        monet_generator_gradients = tape.gradient(total_monet_gen_loss,
                                                  self.m_gen.trainable_variables)
        photo_generator_gradients = tape.gradient(total_photo_gen_loss,
                                                  self.p_gen.trainable_variables)

        monet_discriminator_gradients = tape.gradient(monet_disc_loss,
                                                      self.m_disc.trainable_variables)
        photo_discriminator_gradients = tape.gradient(photo_disc_loss,
                                                      self.p_disc.trainable_variables)

        # Apply the gradients to the optimizer
        self.m_gen_optimizer.apply_gradients(zip(monet_generator_gradients,
                                                 self.m_gen.trainable_variables))

        self.p_gen_optimizer.apply_gradients(zip(photo_generator_gradients,
                                                 self.p_gen.trainable_variables))

        self.m_disc_optimizer.apply_gradients(zip(monet_discriminator_gradients,
                                                  self.m_disc.trainable_variables))

        self.p_disc_optimizer.apply_gradients(zip(photo_discriminator_gradients,
                                                  self.p_disc.trainable_variables))

        return {
            "monet_gen_loss": total_monet_gen_loss,
            "photo_gen_loss": total_photo_gen_loss,
            "monet_disc_loss": monet_disc_loss,
            "photo_disc_loss": photo_disc_loss
        }

# Define loss functions

The discriminator loss function below compares real images to a matrix of 1s and fake images to a matrix of 0s. The perfect discriminator will output all 1s for real images and all 0s for fake images. The discriminator loss outputs the average of the real and generated loss.

In [None]:
with strategy.scope():
    def discriminator_loss(real, generated):
        """Return the discriminator loss for one real/generated pair of score maps.

        Scores of real images are compared with a target of all ones, scores of
        generated images with a target of all zeros, using binary cross-entropy
        on logits without reduction. The result is the mean of the two losses.
        """
        bce = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
        # A perfect discriminator answers "1" everywhere for a real image ...
        loss_on_real = bce(tf.ones_like(real), real)
        # ... and "0" everywhere for a generated image.
        loss_on_generated = bce(tf.zeros_like(generated), generated)

        return (loss_on_real + loss_on_generated) * 0.5

The generator wants to fool the discriminator into thinking the generated image is real. The perfect generator will have the discriminator output only 1s. Thus, it compares the generated image to a matrix of 1s to find the loss.

In [None]:
with strategy.scope():
    def generator_loss(generated):
        """Return the adversarial loss for a generator.

        The discriminator's scores for the generated image are compared with a
        target of all ones: the generator wants its output to be judged real.
        """
        bce = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
        judged_real = tf.ones_like(generated)
        return bce(judged_real, generated)
    

We want our original photo and the twice transformed photo to be similar to one another. Thus, we can calculate the cycle consistency loss by finding the average of their difference.

In [None]:
with strategy.scope():
    def calc_cycle_loss(real_image, cycled_image, LAMBDA):
        """Return ``LAMBDA`` times the mean absolute difference of the two images.

        The inputs are the images themselves, not discriminator scores.
        """
        # No axis is given, so the mean runs over every element.
        mean_abs_diff = tf.reduce_mean(tf.abs(real_image - cycled_image))
        return LAMBDA * mean_abs_diff

The identity loss compares the image with its generator (i.e. photo with photo generator). If the photo generator is given a photo as input, we want it to return the same image, because the input is already a photo. The identity loss compares the input with the output of the generator.

In [None]:
with strategy.scope():
    def identity_loss(real_image, same_image, LAMBDA):
        """Return ``LAMBDA * 0.5`` times the mean absolute difference of the two images."""
        mean_abs_diff = tf.reduce_mean(tf.abs(real_image - same_image))
        return LAMBDA * 0.5 * mean_abs_diff

# Train the CycleGAN

Let's compile our model. Since we used `tf.keras.Model` to build our CycleGAN, we can just use the `fit` function to train our model.

In [None]:
# One Adam optimizer (learning rate 2e-4, beta_1 0.5) per sub-network,
# created inside the strategy scope.
with strategy.scope():
    monet_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    photo_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

    monet_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    photo_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

In [None]:
# Assemble the CycleGAN from the four networks and hand it the optimizers
# and loss functions defined above.
with strategy.scope():
    cycle_gan_model = CycleGan(
        monet_generator, photo_generator, monet_discriminator, photo_discriminator
    )

    cycle_gan_model.compile(
        m_gen_optimizer = monet_generator_optimizer,
        p_gen_optimizer = photo_generator_optimizer,
        m_disc_optimizer = monet_discriminator_optimizer,
        p_disc_optimizer = photo_discriminator_optimizer,
        gen_loss_fn = generator_loss,
        disc_loss_fn = discriminator_loss,
        cycle_loss_fn = calc_cycle_loss,
        identity_loss_fn = identity_loss
    )

In [None]:
# Draw a diagram of the combined model.
# NOTE(review): CycleGan is a subclassed model, so the diagram may show only a single box - confirm the output is useful.
tf.keras.utils.plot_model(cycle_gan_model, show_shapes=True, dpi=100)

In [None]:
class EarlyStopping_custom(tf.keras.callbacks.EarlyStopping):
    """EarlyStopping variant that also accepts non-scalar monitored values.

    The CycleGAN losses are logged as arrays rather than scalars, so the
    monitored value is averaged before it is compared with the best value
    seen so far.

    Args:
        monitor: Name of the logged quantity to watch.
        mode: ``'min'``, ``'max'`` or ``'auto'``, as for ``EarlyStopping``.
        patience: Number of epochs without improvement before stopping.
        restore_best_weights: Restore the weights of the best epoch on stop.
        min_delta: Minimum change that counts as an improvement.
        verbose: Verbosity, as for ``EarlyStopping``.
    """

    def __init__(self, monitor, mode, patience,
                 restore_best_weights=True, min_delta=0, verbose=0):
        # Hand the settings to the parent so that monitor_op, min_delta and
        # patience are configured consistently. Assigning `self.monitor =
        # monitor,` would store a tuple, and a bare `self.patience,` would
        # set nothing at all.
        super(EarlyStopping_custom, self).__init__(
            monitor=monitor,
            min_delta=min_delta,
            patience=patience,
            verbose=verbose,
            mode=mode,
            restore_best_weights=restore_best_weights,
        )

    def on_epoch_end(self, epoch, logs=None):
        """Update the best value and stop training once patience is exhausted."""
        current = self.get_monitor_value(logs)
        if current is None:
            return
        # np.mean handles Python floats and arrays alike, unlike `.mean()`.
        current = float(np.mean(current))
        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
            if self.restore_best_weights:
                self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                # best_weights is still None if no epoch ever improved
                # (e.g. the loss was NaN from the start).
                if self.restore_best_weights and self.best_weights is not None:
                    if self.verbose > 0:
                        print('Restoring model weights from the end of the best epoch.')
                    self.model.set_weights(self.best_weights)

# Early-stopping callback watching the Monet discriminator loss. It is only
# referenced by the commented-out `callbacks` argument of fit().
early_stoping = EarlyStopping_custom(monitor='monet_disc_loss',
                                                 patience=5,
                                                 mode = 'min',
                                                 restore_best_weights=True)

In [None]:
BATCH_SIZE = 2
# The dataset repeats forever, so an "epoch" is defined as one pass over the
# photo set at this batch size.
steps_per_epoch = (n_photo_samples//BATCH_SIZE)
gan_dataset = get_gan_dataset(MONET_FILENAMES,PHOTO_FILENAMES,augment=data_augument,repeat=True,shuffle=True,batch_size=BATCH_SIZE)


history = cycle_gan_model.fit(
    gan_dataset,
    epochs=31,
    steps_per_epoch = steps_per_epoch,
    # Early stopping is currently disabled:
#     callbacks = [early_stoping],
)

In [None]:
np.asarray(history.history['photo_disc_loss']).shape   # Why this shape? The discriminator outputs (batch, 30, 30, 1),
# and BinaryCrossentropy then yields (batch, 30, 30); the leading axis is the number of epochs.
# Open question from an earlier run: a shape starting with 2 (two epochs) also had a 4 - what does the 4 stand for?

# Visualize our Monet-esque photos

In [None]:
# Five rows: the input photo on the left, the generated Monet-style image on
# the right. Pixel values are mapped from [-1, 1] back to 0..255 bytes.
_, ax = plt.subplots(5, 2, figsize=(12, 12))
for i, img in enumerate(photo_ds.take(5)):
    prediction = monet_generator(img, training=False)[0].numpy()
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

    for _column, (_picture, _caption) in enumerate(
            ((img, "Input Photo"), (prediction, "Monet-esque"))):
        ax[i, _column].imshow(_picture)
    for _column, _caption in enumerate(("Input Photo", "Monet-esque")):
        ax[i, _column].set_title(_caption)
    for _column in range(2):
        ax[i, _column].axis("off")
plt.show()





# Create submission file

In [None]:
import PIL 
# Create the output directory with plain Python instead of the `! mkdir`
# shell magic: this also works outside IPython and does not fail when the
# cell is re-run and the directory already exists.
import os
os.makedirs('../images', exist_ok=True)

In [None]:
# i = 1
# for img in photo_ds:
#     prediction = monet_generator(img, training=False)[0].numpy()
#     prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
#     im = PIL.Image.fromarray(prediction)
#     im.save("../images/" + str(i) + ".jpg")
#     i += 1
    
    
import PIL
def predict_and_save(input_ds, generator_model, output_path):
    """Run the generator on every image and save the results as JPEG files.

    Files are named ``1.jpg``, ``2.jpg``, ... and written to
    ``output_path + '<n>.jpg'``, so ``output_path`` must end with a path
    separator when it names a directory.

    Args:
        input_ds: Iterable of image batches with values in [-1, 1].
        generator_model: Callable model, invoked with ``training=False``.
        output_path: Prefix prepended to each file name.
    """
    # `import PIL` alone does not load the Image submodule; import it
    # explicitly instead of relying on another library having done so.
    from PIL import Image

    i = 1
    for batch in input_ds:
        predictions = generator_model(batch, training=False).numpy()  # make prediction
        predictions = (predictions * 127.5 + 127.5).astype(np.uint8)  # re-scale to 0..255
        # Save every image of the batch, not only the first one.
        for prediction in predictions:
            Image.fromarray(prediction).save(f'{output_path}{i}.jpg')
            i += 1

In [None]:
# Convert every photo (in batches of one) with the Monet generator and write the results to ../images/.
predict_and_save(load_dataset(PHOTO_FILENAMES).batch(1), monet_generator, '../images/')

In [None]:
import shutil
# Zip the generated images into /kaggle/working/images.zip.
# NOTE(review): this assumes the notebook runs from /kaggle/working, so that '../images' is /kaggle/images - confirm.
shutil.make_archive("/kaggle/working/images", 'zip', "/kaggle/images")

# print(f'number of generated samples :{len([name for name in os.listdir('/kaggle/images/') if os.path.isfile(os.path.join('/kaggle/images/', name))])}')