In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as plt
import numpy as np
import re
import os
import math
import random
import cv2

# Pick the distribution strategy: use the TPU when one can be resolved and
# initialised, otherwise fall back to TensorFlow's default strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as tpu_error:
    # `Exception` (rather than a bare `except:`) keeps KeyboardInterrupt and
    # SystemExit working, and the message records why the TPU was not used.
    print('TPU not available, falling back to the default strategy:', tpu_error)
    strategy = tf.distribute.get_strategy()

# 读取数据

读取jpg文件

In [None]:
# Locate the competition dataset bucket and list the JPEG files of both domains.
GCS_PATH = KaggleDatasets().get_gcs_path('gan-getting-started')
fn_monet = tf.io.gfile.glob(str(f"{GCS_PATH}/monet_jpg/*.jpg"))
fn_photo = tf.io.gfile.glob(str(f"{GCS_PATH}/photo_jpg/*.jpg"))

In [None]:
BATCH_SIZE = 4


def parse_function(filename):
    """Read one JPEG file and return it as a float32 tensor scaled to [-1, 1].

    Args:
        filename: path of the JPEG file to read.

    Returns:
        A tensor reshaped to [256, 256, 3].
    """
    raw_bytes = tf.io.read_file(filename)
    pixels = tf.cast(tf.image.decode_jpeg(raw_bytes, channels=3), tf.float32)
    # Map the 0..255 pixel range onto -1..1 (matches the generator's tanh output).
    scaled = (pixels / 127.5) - 1
    # The reshape pins the static shape so downstream ops know the image size.
    return tf.reshape(scaled, [256, 256, 3])

def data_augment(image):
    """Randomly crop, rotate and flip a single image.

    Three independent uniform draws in [0, 1) decide which augmentations run:

    * ``p_crop > .5``: resize to 286x286 and take a random 256x256 crop;
      additionally, when ``p_crop > .9``, resize that crop to 300x300 and
      take another random 256x256 crop.
    * ``p_rotate``: apply ``tf.image.rot90`` with k=3 above .9, k=2 above .7,
      k=1 above .5, and no rotation otherwise.
    * ``p_spatial > .6``: apply the random left/right and up/down flip ops;
      additionally, when ``p_spatial > .9``, transpose the image.

    Args:
        image: image tensor; the crop size used here is [256, 256, 3].

    Returns:
        The augmented image tensor.
    """
    # One scalar random draw per augmentation family.
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_crop = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    if p_crop > .5:
        # Upscale slightly, then crop back to the working resolution.
        image = tf.image.resize(image, [286, 286])
        image = tf.image.random_crop(image, size=[256, 256, 3])
        if p_crop > .9:
            # Second, stronger zoom applied on top of the first crop.
            image = tf.image.resize(image, [300, 300])
            image = tf.image.random_crop(image, size=[256, 256, 3])

    # Rotation by a multiple of 90 degrees (k quarter turns).
    if p_rotate > .9:
        image = tf.image.rot90(image, k=3)
    elif p_rotate > .7:
        image = tf.image.rot90(image, k=2)
    elif p_rotate > .5:
        image = tf.image.rot90(image, k=1)

    if p_spatial > .6:
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        if p_spatial > .9:
            image = tf.image.transpose(image)

    return image

# AUTOTUNE lets tf.data choose the parallelism / prefetch buffer size itself.
# The same value is reused for `map` and `prefetch` below and by later cells.
num_parallel_calls=tf.data.experimental.AUTOTUNE
def getSet(filenames):
    """Build the infinite, shuffled, augmented and batched training dataset.

    Args:
        filenames: list of JPEG file paths for one image domain.

    Returns:
        A ``tf.data.Dataset`` yielding batches of ``BATCH_SIZE`` images
        indefinitely.

    Raises:
        ValueError: if ``filenames`` is empty (nothing to train on).
    """
    if len(filenames) == 0:
        # shuffle() needs a positive buffer size; fail early with a clear message.
        raise ValueError("getSet() received an empty list of filenames")

    dataset = tf.data.Dataset.from_tensor_slices(filenames)
    # Shuffle the (cheap) filename strings with a buffer covering the full list.
    dataset = dataset.shuffle(len(filenames))
    dataset = dataset.map(parse_function, num_parallel_calls=num_parallel_calls)
    dataset = dataset.map(data_augment, num_parallel_calls=num_parallel_calls)
    dataset = dataset.repeat()
    dataset = dataset.batch(BATCH_SIZE,drop_remainder=True)
    # No cache() here: placed after repeat() it would sit on an infinite stream,
    # so the cache could never complete and would keep buffering batches in
    # memory; it would also freeze the random augmentation applied above.
    dataset = dataset.prefetch(num_parallel_calls)
    return dataset

# One training stream per image domain.
monet_ds = getSet(fn_monet)
photo_ds = getSet(fn_photo)

# 生成器的实现

生成器使用U-Net来实现。U-Net主要被用于图像分割。由于是个图像到图像的网络，因此在CycleGAN中可以被用于生成器

U-Net首先对图像进行下采样，方法是使用步长为2的$4\times 4$**卷积**，使用padding="same"，将使得卷积后图片的尺寸变为“原尺寸/步长“的向上取整（如果padding="valid"，则图片的尺寸会变为“(原尺寸-卷积核尺寸+1)/步长“的向上取整）。第一次卷积，使用64个卷积核将3通道变成64通道，之后每一次使用的卷积核数目是上一次的两倍，从而将通道数增大一倍，最大增大到512通道。**保存下采样的每一步的结果**，在之后会用到。

等到把图片大小卷积到$1\times 1$后，转为对图片进行升采样，方法是使用步长为2的$4\times 4$**逆卷积**，使用padding="same"，将使得逆卷积后图片的尺寸变为“原尺寸\*步长“（如果padding="valid"，则图片的尺寸会变为“原尺寸\*步长+max(卷积核尺寸-步长, 0)“）。逆卷积使用的卷积核数目依次为512、512、512、512、256、128、64，即前几层保持512，之后每一次是前一次的1/2倍。上采样过程中，每次会沿通道维度拼接下采样过程中对应步的结果，因此拼接后的通道数是该层卷积核数目的两倍。最后一次逆卷积使用3个卷积核，将拼接得到的128通道变成3通道。

In [None]:
OUTPUT_CHANNELS = 3

def downsample(filters, size, apply_instancenorm=True):
    """Build a stride-2 convolution block: Conv2D -> [InstanceNorm] -> LeakyReLU.

    Args:
        filters: number of convolution filters (output channels).
        size: convolution kernel size.
        apply_instancenorm: whether to insert instance normalisation after the
            convolution.

    Returns:
        A ``keras.Sequential`` holding the block's layers.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    norm_gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    # The convolution has no bias term (use_bias=False).
    stages = [
        layers.Conv2D(
            filters,
            size,
            strides=2,
            padding='same',
            kernel_initializer=kernel_init,
            use_bias=False,
        )
    ]
    if apply_instancenorm:
        stages.append(tfa.layers.InstanceNormalization(gamma_initializer=norm_gamma_init))
    stages.append(layers.LeakyReLU())

    return keras.Sequential(stages)

In [None]:
def upsample(filters, size, apply_dropout=False):
    """Build a stride-2 transposed-convolution block.

    Layer order: Conv2DTranspose -> InstanceNorm -> [Dropout(0.5)] -> ReLU.

    Args:
        filters: number of transposed-convolution filters (output channels).
        size: kernel size.
        apply_dropout: whether to insert a 50% dropout layer before the ReLU.

    Returns:
        A ``keras.Sequential`` holding the block's layers.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    norm_gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    stages = [
        layers.Conv2DTranspose(
            filters,
            size,
            strides=2,
            padding='same',
            kernel_initializer=kernel_init,
            use_bias=False,
        ),
        tfa.layers.InstanceNormalization(gamma_initializer=norm_gamma_init),
    ]
    if apply_dropout:
        stages.append(layers.Dropout(0.5))
    stages.append(layers.ReLU())

    return keras.Sequential(stages)

In [None]:
def Generator():
    """Build the U-Net generator mapping a 256x256x3 image to a 256x256x3 image.

    The encoder halves the spatial size eight times (down to 1x1); the decoder
    doubles it back, concatenating the matching encoder output along the
    channel axis after every upsampling block. The final transposed
    convolution uses a tanh activation.

    Returns:
        A ``keras.Model`` from the input image to the generated image.
    """
    inputs = layers.Input(shape=[256,256,3])

    # Shape comments use bs = batch size.
    # Encoder blocks; only the first one skips instance normalisation.
    down_stack = [downsample(64, 4, apply_instancenorm=False)]   # (bs, 128, 128, 64)
    down_stack += [
        downsample(n_filters, 4)
        # 128 -> (bs, 64, 64, 128), 256 -> (bs, 32, 32, 256),
        # then 512 filters at 16x16, 8x8, 4x4, 2x2 and finally 1x1.
        for n_filters in (128, 256, 512, 512, 512, 512, 512)
    ]

    # Decoder blocks; the first three use dropout. Shapes listed are those
    # after concatenation with the skip tensor, hence the doubled channels:
    # (bs, 2, 2, 1024), (bs, 4, 4, 1024), (bs, 8, 8, 1024)
    up_stack = [upsample(512, 4, apply_dropout=True) for _ in range(3)]
    # (bs, 16, 16, 1024), (bs, 32, 32, 512), (bs, 64, 64, 256), (bs, 128, 128, 128)
    up_stack += [upsample(n_filters, 4) for n_filters in (512, 256, 128, 64)]

    initializer = tf.random_normal_initializer(0., 0.02)
    # Final layer: back to the input resolution, (bs, 256, 256, 3).
    last = layers.Conv2DTranspose(
        OUTPUT_CHANNELS,
        4,
        strides=2,
        padding='same',
        kernel_initializer=initializer,
        activation='tanh',
    )

    # Encoder pass, remembering every intermediate output for the skip links.
    features = inputs
    encoder_outputs = []
    for down_block in down_stack:
        features = down_block(features)
        encoder_outputs.append(features)

    # Decoder pass: skip the bottleneck output itself and walk the remaining
    # encoder outputs from deepest to shallowest.
    for up_block, skip in zip(up_stack, encoder_outputs[-2::-1]):
        upsampled = up_block(features)
        features = layers.Concatenate()([upsampled, skip])

    return keras.Model(inputs=inputs, outputs=last(features))

In [None]:
# Draw the layer graph (with tensor shapes) of a freshly built, throwaway generator.
tf.keras.utils.plot_model(Generator(), show_shapes=True, dpi=64)

# 鉴别器的实现

鉴别器使用的是普通的下采样卷积神经网络，下采样3次后，输出一个$30\times30$的矩阵作为结果，具有指示真实分类的较高像素值和指示假分类的较低值

In [None]:
def Discriminator():
    """Build the patch discriminator: 256x256x3 image in, 30x30x1 score map out.

    Instead of a single score, the network outputs a small 2-D map. The final
    convolution has no activation, so the map holds raw (unbounded) values.

    Returns:
        A ``tf.keras.Model`` from the input image to the score map.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    inp = layers.Input(shape=[256, 256, 3], name='input_image')

    # Three stride-2 blocks; the first one without instance normalisation.
    x = downsample(64, 4, False)(inp)   # (bs, 128, 128, 64)
    x = downsample(128, 4)(x)           # (bs, 64, 64, 128)
    x = downsample(256, 4)(x)           # (bs, 32, 32, 256)

    x = layers.ZeroPadding2D()(x)       # (bs, 34, 34, 256)
    x = layers.Conv2D(
        512,
        4,
        strides=1,
        kernel_initializer=initializer,
        use_bias=False,
    )(x)                                # (bs, 31, 31, 512)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(x)
    x = layers.LeakyReLU()(x)

    x = layers.ZeroPadding2D()(x)       # (bs, 33, 33, 512)
    patch_scores = layers.Conv2D(
        1,
        4,
        strides=1,
        kernel_initializer=initializer,
    )(x)                                # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=inp, outputs=patch_scores)

In [None]:
# Draw the layer graph (with tensor shapes) of a freshly built, throwaway discriminator.
tf.keras.utils.plot_model(Discriminator(), show_shapes=True, dpi=64)

# CycleGAN

CycleGAN包含两个生成器$G$和$F$，对应两个鉴别器$D_X$和$D_Y$

* 生成器：$G(x)$，$x$为分布$p_x$中取样的样本，目标是其生成的样本和分布$p_y$中取样的样本$y$尽可能接近
* 鉴别器：$D_Y(y)$，目标是尽可能区分$G(x)$和$y$，输出样本来自$p_y$（真实样本）而不是由$G$生成的概率

**对抗性损失**1：$L_{GAN}(G,D_Y)=E_{y\sim p_y}[\log D_Y(y)]+E_{x\sim p_x}[\log(1-D_Y(G(x)))]$

* 生成器：$F(y)$，$y$为分布$p_y$中取样的样本，目标是其生成的样本和分布$p_x$中取样的样本$x$尽可能接近
* 鉴别器：$D_X(x)$，目标是尽可能区分$F(y)$和$x$，输出样本来自$p_x$（真实样本）而不是由$F$生成的概率

**对抗性损失**2：$L_{GAN}(F,D_X)=E_{x\sim p_x}[\log D_X(x)]+E_{y\sim p_y}[\log(1-D_X(F(y)))]$

尽管上述的对抗性损失能够让生成器$G$和生成器$F$学习到$p_x$和$p_y$，但是却没有保证从$x$得到$G(x)$时图像的内容不变，因为$G(x)$只需要符合$p_y$即可，并没有对其施加约束，所以$x$到$G(x)$包含很多种可能的映射

为此，使用**循环一致性损失**来作为约束，使得$G$生成的$G(x)$在内容上仍然能和$x$保持一致

循环一致性损失：$L_{cyc}(G,F)=E_{x\sim p_x}[||F(G(x))-x||_1]+E_{y\sim p_y}[||G(F(y))-y||_1]$

总体损失：$L(G,F,D_X,D_Y)=L_{GAN}(G,D_Y)+L_{GAN}(F,D_X)+\lambda L_{cyc}(G,F)$

以上是CycleGAN论文中的内容，实现的时候我们实际上是这么做的：

* 对鉴别器$D_Y$而言，要最小化的损失函数为：（$D_Y(y)$偏离1的程度+$D_Y(G(x))$偏离0的程度）/2
* 对鉴别器$D_X$而言，要最小化的损失函数为：（$D_X(x)$偏离1的程度+$D_X(F(y))$偏离0的程度）/2
* 循环一致性损失为：$F(G(x))$与$x$的距离（像素平均）+$G(F(y))$与$y$的距离（像素平均）
* 对生成器$G$而言，要最小化的损失函数为：$D_Y(G(x))$偏离1的程度+$\lambda$*循环一致性损失
* 对生成器$F$而言，要最小化的损失函数为：$D_X(F(y))$偏离1的程度+$\lambda$*循环一致性损失

我们让鉴别器输出一个$30\times30$的矩阵作为结果，并**与全0矩阵或全1矩阵计算交叉熵**，以衡量输出结果接近0或接近1的程度

本代码还包括一个自我一致性损失，这是论文中没有的部分，计算的是$G(y)$与$y$的距离（像素平均）和$F(x)$与$x$的距离（像素平均），分别乘以$0.5\lambda$后加到对应生成器的损失函数上（前者加到$G$的损失上，后者加到$F$的损失上）

In [None]:
class CycleGan(keras.Model):
    """Keras model bundling two generators and two discriminators for CycleGAN training.

    Notation used in the comments below: real photos are ``x`` and real Monet
    paintings are ``y``; ``m_gen`` is G (photo -> Monet), ``p_gen`` is F
    (Monet -> photo), ``m_disc`` is D_Y and ``p_disc`` is D_X.
    """
    def __init__(self,monet_generator,photo_generator,monet_discriminator,photo_discriminator,lambda_cycle=15):
        """Store the four sub-models and the weight ``lambda_cycle`` passed to the cycle/identity losses."""
        super(CycleGan, self).__init__()
        self.m_gen = monet_generator
        self.p_gen = photo_generator
        self.m_disc = monet_discriminator
        self.p_disc = photo_discriminator
        self.lambda_cycle = lambda_cycle
        
    def compile(self,m_gen_optimizer,p_gen_optimizer,m_disc_optimizer,p_disc_optimizer,gen_loss_fn,disc_loss_fn,cycle_loss_fn,identity_loss_fn):
        """Register one optimizer per sub-model and the four loss functions used by ``train_step``."""
        super(CycleGan, self).compile()
        self.m_gen_optimizer = m_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.m_disc_optimizer = m_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer
        
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn
        
    def train_step(self, batch_data):
        """Run one optimisation step on a ``(real_monet, real_photo)`` batch pair.

        Returns:
            A dict with the total generator losses and the discriminator
            losses for both domains.
        """
        real_monet, real_photo = batch_data
        # real_monet is y, real_photo is x, m_gen is G, p_gen is F, p_disc is D_X, m_disc is D_Y
        
        # persistent=True because four separate gradient computations read this tape.
        with tf.GradientTape(persistent=True) as tape:
            fake_monet = self.m_gen(real_photo, training=True)  # G(x)
            cycled_photo = self.p_gen(fake_monet, training=True)  # F(G(x))

            fake_photo = self.p_gen(real_monet, training=True)  # F(y)
            cycled_monet = self.m_gen(fake_photo, training=True)  # G(F(y))

            same_monet = self.m_gen(real_monet, training=True)  # G(y)
            same_photo = self.p_gen(real_photo, training=True)  # F(x)

            disc_real_monet = self.m_disc(real_monet, training=True)  # D_Y(y)
            disc_real_photo = self.p_disc(real_photo, training=True)  # D_X(x)

            disc_fake_monet = self.m_disc(fake_monet, training=True)  # D_Y(G(x))
            disc_fake_photo = self.p_disc(fake_photo, training=True)  # D_X(F(y))

            # Adversarial (base) part of each generator's loss.
            monet_gen_loss = self.gen_loss_fn(disc_fake_monet)
            photo_gen_loss = self.gen_loss_fn(disc_fake_photo)

            # Cycle-consistency loss over both directions; lambda_cycle is passed to the loss function.
            total_cycle_loss = self.cycle_loss_fn(real_monet, cycled_monet, self.lambda_cycle) + self.cycle_loss_fn(real_photo, cycled_photo, self.lambda_cycle)

            # evaluates total generator loss
            # (adversarial + shared cycle loss + the generator's own identity loss)
            total_monet_gen_loss = monet_gen_loss + total_cycle_loss + self.identity_loss_fn(real_monet, same_monet, self.lambda_cycle)
            total_photo_gen_loss = photo_gen_loss + total_cycle_loss + self.identity_loss_fn(real_photo, same_photo, self.lambda_cycle)

            # evaluates discriminator loss
            monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet)
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo)

        # Differentiate each loss with respect to the trainable variables of the
        # network it trains. This works because the forward passes above were
        # recorded by the gradient tape.
        monet_generator_gradients = tape.gradient(total_monet_gen_loss,self.m_gen.trainable_variables)  # total G loss w.r.t. G's variables
        photo_generator_gradients = tape.gradient(total_photo_gen_loss,self.p_gen.trainable_variables)  # total F loss w.r.t. F's variables
        monet_discriminator_gradients = tape.gradient(monet_disc_loss,self.m_disc.trainable_variables)  # D_Y loss w.r.t. D_Y's variables
        photo_discriminator_gradients = tape.gradient(photo_disc_loss,self.p_disc.trainable_variables)  # D_X loss w.r.t. D_X's variables

        # All gradients are computed before any variable is updated; each optimizer then takes one step.
        self.m_gen_optimizer.apply_gradients(zip(monet_generator_gradients,self.m_gen.trainable_variables))
        self.p_gen_optimizer.apply_gradients(zip(photo_generator_gradients,self.p_gen.trainable_variables))
        self.m_disc_optimizer.apply_gradients(zip(monet_discriminator_gradients,self.m_disc.trainable_variables))
        self.p_disc_optimizer.apply_gradients(zip(photo_discriminator_gradients,self.p_disc.trainable_variables))
        
        return {"monet_gen_loss": total_monet_gen_loss,"photo_gen_loss": total_photo_gen_loss,"monet_disc_loss": monet_disc_loss,"photo_disc_loss": photo_disc_loss}

# 损失函数

用于训练鉴别器的损失函数

In [None]:
with strategy.scope():
    def discriminator_loss(real, generated):
        """Discriminator loss: real outputs are pushed towards 1, generated outputs towards 0.

        Args:
            real: discriminator output (logits) for real images.
            generated: discriminator output (logits) for generated images.

        Returns:
            Half the sum of both binary cross-entropy terms, left unreduced
            (``Reduction.NONE``).
        """
        bce = tf.keras.losses.BinaryCrossentropy(
            from_logits=True,
            reduction=tf.keras.losses.Reduction.NONE,
        )
        # Real samples are compared against an all-ones target ...
        loss_on_real = bce(tf.ones_like(real), real)
        # ... and generated samples against an all-zeros target.
        loss_on_generated = bce(tf.zeros_like(generated), generated)
        return (loss_on_real + loss_on_generated) * 0.5

用于训练生成器的损失函数（基本部分）

In [None]:
with strategy.scope():
    def generator_loss(generated):
        """Adversarial generator loss: cross-entropy of the discriminator output against all ones.

        Args:
            generated: discriminator output (logits) for generated images.

        Returns:
            The unreduced (``Reduction.NONE``) binary cross-entropy.
        """
        bce = tf.keras.losses.BinaryCrossentropy(
            from_logits=True,
            reduction=tf.keras.losses.Reduction.NONE,
        )
        # The generator wants the discriminator to answer "1" on its samples.
        all_ones = tf.ones_like(generated)
        return bce(all_ones, generated)

循环一致性损失

In [None]:
with strategy.scope():
    def calc_cycle_loss(real_image, cycled_image, LAMBDA):
        """Cycle-consistency loss: LAMBDA times the mean absolute difference.

        Args:
            real_image: the original image batch.
            cycled_image: the batch after passing through both generators.
            LAMBDA: weight applied to the loss.
        """
        mean_abs_error = tf.reduce_mean(tf.abs(real_image - cycled_image))
        return LAMBDA * mean_abs_error

自我一致损失（论文中没有的部分）

In [None]:
with strategy.scope():
    def identity_loss(real_image, same_image, LAMBDA):
        """Identity loss: LAMBDA * 0.5 times the mean absolute difference.

        Args:
            real_image: an image batch from the generator's target domain.
            same_image: the generator's output for that same batch.
            LAMBDA: weight applied (halved) to the loss.
        """
        mean_abs_error = tf.reduce_mean(tf.abs(real_image - same_image))
        return LAMBDA * 0.5 * mean_abs_error

# 训练

In [None]:
with strategy.scope():
    # Four independent Adam optimizers with identical settings, one per network.
    (
        monet_generator_optimizer,
        photo_generator_optimizer,
        monet_discriminator_optimizer,
        photo_discriminator_optimizer,
    ) = (tf.keras.optimizers.Adam(2e-4, beta_1=0.5) for _ in range(4))

In [None]:
# Build the networks and the combined model inside the distribution strategy.
with strategy.scope():
    monet_generator = Generator()
    photo_generator = Generator()
    monet_discriminator = Discriminator()
    photo_discriminator = Discriminator()

    cycle_gan_model = CycleGan(
        monet_generator,
        photo_generator,
        monet_discriminator,
        photo_discriminator,
    )
    cycle_gan_model.compile(
        m_gen_optimizer=monet_generator_optimizer,
        p_gen_optimizer=photo_generator_optimizer,
        m_disc_optimizer=monet_discriminator_optimizer,
        p_disc_optimizer=photo_discriminator_optimizer,
        gen_loss_fn=generator_loss,
        disc_loss_fn=discriminator_loss,
        cycle_loss_fn=calc_cycle_loss,
        identity_loss_fn=identity_loss,
    )

In [None]:
# Both datasets repeat forever, so an "epoch" is defined by the larger file list.
steps_per_epoch = max(len(fn_monet), len(fn_photo)) // BATCH_SIZE
cycle_gan_model.fit(
    tf.data.Dataset.zip((monet_ds, photo_ds)),
    epochs=25,
    steps_per_epoch=steps_per_epoch,
)

# 可视化结果

In [None]:
# Inference stream: every photo once, shuffled, un-augmented, one image per batch.
testset = (
    tf.data.Dataset.from_tensor_slices(fn_photo)
    .shuffle(len(fn_photo))
    .map(parse_function, num_parallel_calls)
    .batch(1)
    .prefetch(num_parallel_calls)
)


# Show five photos (left column) next to their generated Monet versions (right column).
_, ax = plt.subplots(5, 2, figsize=(20, 20))
for i, img in enumerate(testset.take(5)):
    prediction = monet_generator(img, training=False)[0].numpy()
    # Undo the [-1, 1] scaling to get displayable 8-bit pixels.
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

    for column, (picture, title) in enumerate(((img, "Input Photo"), (prediction, "Monet-esque"))):
        ax[i, column].imshow(picture)
        ax[i, column].set_title(title)
        ax[i, column].axis("off")
plt.show()

# 提交

In [None]:
# Import the submodule explicitly: a bare `import PIL` does not guarantee that
# `PIL.Image` has been loaded.
import PIL.Image

def predict_and_save(input_ds, generator_model, output_path):
    """Run the generator on every batch of ``input_ds`` and save the results as JPEGs.

    Files are named ``1.jpg``, ``2.jpg``, ... in iteration order. Only the
    first image of each batch is saved.

    Args:
        input_ds: iterable of image batches scaled to [-1, 1].
        generator_model: callable model producing an image batch in [-1, 1].
        output_path: existing directory that receives the JPEG files.
    """
    for index, img in enumerate(input_ds, start=1):
        prediction = generator_model(img, training=False)[0].numpy()  # make prediction
        prediction = (prediction * 127.5 + 127.5).astype(np.uint8)    # re-scale to 0..255
        # os.path.join avoids the doubled separator when output_path ends with "/".
        PIL.Image.fromarray(prediction).save(os.path.join(output_path, f"{index}.jpg"))

In [None]:
import os
# Create the folder for the generated images; exist_ok lets the cell be re-run.
os.makedirs('../images/', exist_ok=True)

predict_and_save(testset, monet_generator, '../images/')

In [None]:
import shutil
# The base name must not end with a slash: ".zip" is appended to it, and
# depending on the Python version a trailing slash can yield
# "/kaggle/working/images/.zip" instead of "/kaggle/working/images.zip".
shutil.make_archive('/kaggle/working/images', 'zip', '../images')

# Count the regular files that were written to the image folder.
print(f"Generated samples: {sum(os.path.isfile(os.path.join('../images/', name)) for name in os.listdir('../images/'))}")