In [None]:
import os
import re
import shutil
import warnings

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from kaggle_datasets import KaggleDatasets
from scipy import linalg
from tensorflow import keras
from tensorflow.keras import layers

In [None]:
# Global hyper-parameters and input-pipeline settings.
alpha = 0.2  # NOTE(review): not referenced in the visible notebook — confirm it is still needed
shape = [256, 256, 3]  # NOTE(review): also unreferenced in the visible code
Lambda_cycle = 15  # weight for the cycle-consistency loss
Lambda_identity = 10  # weight for the identity loss
BATCH_SIZE = 30  # batch size handed to get_gan_dataset
#EPOCHS = 1
AUTOTUNE = tf.data.experimental.AUTOTUNE  # passed to map()/prefetch() so tf.data picks the parallelism

In [None]:
# Connect to a TPU when one is attached to the session; otherwise fall back to
# TensorFlow's default distribution strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as tpu_error:
    # A bare `except:` would also swallow KeyboardInterrupt/SystemExit; catch only
    # real errors and report why the TPU path was skipped.
    print('TPU unavailable, using default strategy:', tpu_error)
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

In [None]:
# Resolve the competition dataset's GCS location and list the TFRecord shards
# for each image domain.
GCS_PATH = KaggleDatasets().get_gcs_path('gan-getting-started')

MONET_FILENAMES = tf.io.gfile.glob(f'{GCS_PATH}/monet_tfrec/monet*.tfrec')
PHOTO_FILENAMES = tf.io.gfile.glob(f'{GCS_PATH}/photo_tfrec/photo*.tfrec')

def count_data_items(filenames):
    """Sum the per-file record counts encoded in TFRecord file names.

    Each name must contain ``-<digits>.`` (e.g. ``monet00-60.tfrec``); the digits
    captured by the first such match are taken as that file's item count.
    """
    count_pattern = re.compile(r"-([0-9]*)\.")
    per_file_counts = []
    for path in filenames:
        per_file_counts.append(int(count_pattern.search(path).group(1)))
    return np.sum(per_file_counts)

# Number of images in each domain, derived from the shard file names.
n_monet_samples = count_data_items(MONET_FILENAMES)
n_photo_samples = count_data_items(PHOTO_FILENAMES)   # number of photo images

def data_augment(image):   # data augmentation
    """Randomly crop, rotate and mirror a single image.

    Three independent uniform draws between 0 and 1 gate the three augmentation
    groups (crop, rotation, mirroring). The crop size is hard-coded to
    256 x 256 x 3, so the input is presumably a 256 x 256 RGB image — TODO confirm.

    Returns the (possibly) augmented image.
    """
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_crop = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    # Crop group: applied when p_crop > 0.5; a second, stronger resize+crop is
    # stacked on top when p_crop > 0.9.
    if p_crop > .5:
        image = tf.image.resize(image, [286, 286]) #resizing to 286 x 286 x 3
        image = tf.image.random_crop(image, size=[256, 256, 3]) # randomly cropping to 256 x 256 x 3
        if p_crop > .9:
            image = tf.image.resize(image, [300, 300])
            image = tf.image.random_crop(image, size=[256, 256, 3])
    
    # Rotation group: exactly one of 270/180/90 degrees is applied when
    # p_rotate > 0.5; no rotation otherwise.
    if p_rotate > .9:
        image = tf.image.rot90(image, k=3) # rotate 270º
    elif p_rotate > .7:
        image = tf.image.rot90(image, k=2) # rotate 180º
    elif p_rotate > .5:
        image = tf.image.rot90(image, k=1) # rotate 90º
        
        ## random mirroring
    # Mirroring group: random flips when p_spatial > 0.6, plus a transpose when
    # p_spatial > 0.9.
    if p_spatial > .6:
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        if p_spatial > .9:
            image = tf.image.transpose(image)
    
    return image

# Height and width that decode_image reshapes every decoded image to.
IMAGE_SIZE = [256, 256]

def decode_image(image):
    """Decode JPEG bytes into a float32 RGB image scaled to [-1, 1].

    The result is reshaped to ``[*IMAGE_SIZE, 3]`` so later ops see a static shape.
    """
    pixels = tf.cast(tf.image.decode_jpeg(image, channels=3), tf.float32)
    scaled = (pixels / 127.5) - 1
    return tf.reshape(scaled, [*IMAGE_SIZE, 3])

def read_tfrecord(example):
    """Parse one serialized TFRecord example and return its decoded image.

    ``image_name`` and ``target`` are parsed as required scalar string features,
    but only ``image`` is used.
    """
    feature_spec = {
        key: tf.io.FixedLenFeature([], tf.string)
        for key in ("image_name", "image", "target")
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image'])

def load_dataset(filenames):
    """Build a dataset of decoded images from the given TFRecord files."""
    records = tf.data.TFRecordDataset(filenames)
    return records.map(read_tfrecord, num_parallel_calls=AUTOTUNE)

def get_gan_dataset(monet_files, photo_files, augment=None, repeat=True, shuffle=True, batch_size=1):
    """Build the paired training dataset plus the two per-domain datasets.

    Args:
        monet_files: TFRecord paths for the Monet domain.
        photo_files: TFRecord paths for the photo domain.
        augment: optional per-image function mapped over both datasets.
        repeat: repeat both datasets indefinitely when true.
        shuffle: shuffle both datasets with a 2048-element buffer when true.
        batch_size: batch size; incomplete final batches are dropped.

    Returns:
        ``(gan_ds, photo_ds, monet_ds)`` where ``gan_ds`` zips
        ``(monet_ds, photo_ds)`` batch by batch.
    """

    def _prepare(files):
        # Cache right after decoding. Caching after repeat()/shuffle()/augment
        # (as before) never finishes on an infinite dataset, keeps growing in
        # memory, and would replay the same augmented, shuffled batches.
        ds = load_dataset(files).cache()
        if augment:
            ds = ds.map(augment, num_parallel_calls=AUTOTUNE)
        if repeat:
            ds = ds.repeat()
        if shuffle:
            ds = ds.shuffle(2048)
        ds = ds.batch(batch_size, drop_remainder=True)
        return ds.prefetch(AUTOTUNE)

    monet_ds = _prepare(monet_files)
    photo_ds = _prepare(photo_files)

    # Pair one Monet batch with one photo batch per training step.
    gan_ds = tf.data.Dataset.zip((monet_ds, photo_ds))

    return gan_ds,photo_ds,monet_ds

In [None]:
# Build the full training dataset (paired, augmented, shuffled, repeating).
full_dataset,my_photo_ds,my_monet_ds = get_gan_dataset(MONET_FILENAMES, PHOTO_FILENAMES, augment=data_augment, repeat=True, shuffle=True, batch_size=BATCH_SIZE)

# Visualize one sample from each domain to sanity-check the pipeline above.
example_monet , example_photo = next(iter(full_dataset))
plt.subplot(121)
plt.title('Real photo')
plt.imshow(example_photo[2] * 0.5 + 0.5)  # map [-1, 1] back to [0, 1] for display

plt.subplot(122)
plt.title('Monet painting')
plt.imshow(example_monet[2]* 0.5 + 0.5)

In [None]:
# Un-augmented, un-shuffled, single-pass datasets used for the FID-style metric;
# the batch size is 32 per replica.
fast_photo_ds = load_dataset(PHOTO_FILENAMES).batch(32*strategy.num_replicas_in_sync).prefetch(32)
monet_ds_fid = load_dataset(MONET_FILENAMES).batch(32*strategy.num_replicas_in_sync).prefetch(32)

In [None]:
# Frozen feature extractor for the FID-style metric.
with strategy.scope():
    # Start from InceptionV3 without its classification head.
    inception_model = tf.keras.applications.InceptionV3(input_shape=(256,256,3),pooling="avg",include_top=False)

    # Features are taken from the intermediate "mixed3" block and global-max-pooled,
    # not from the network's final pooled output.
    mix3  = inception_model.get_layer("mixed3").output
    f0 = tf.keras.layers.GlobalMaxPooling2D()(mix3)

    # Rebind the name to the truncated model and freeze its weights.
    inception_model = tf.keras.Model(inputs=inception_model.input, outputs=f0)
    inception_model.trainable = False
    
    
def calculate_activation_statistics_mod(images,fid_model):
    """Return the mean vector and covariance matrix of a model's activations.

    ``fid_model.predict(images)`` must yield a 2-D array with one row per sample;
    the mean is taken over rows and the covariance treats columns as variables.
    """
    activations = fid_model.predict(images)
    return np.mean(activations, axis=0), np.cov(activations, rowvar=False)
# Reference statistics of the real Monet images, computed once and reused for every FID evaluation.
myFID_mu2, myFID_sigma2 = calculate_activation_statistics_mod(monet_ds_fid,inception_model)

In [None]:
# Sanity check: shapes of the reference mean vector and covariance matrix.
print(myFID_mu2.shape,myFID_sigma2.shape)

In [None]:
def calculate_frechet_distance(mu1,sigma1,mu2,sigma2):
    """Compute the Frechet distance between two Gaussians.

    The value is ``|mu1 - mu2|^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrt(sigma1 @ sigma2))``.

    Args:
        mu1, mu2: mean vectors of equal shape.
        sigma1, sigma2: covariance matrices of equal shape.

    Returns:
        The scalar distance.

    Raises:
        AssertionError: if the means or the covariances differ in shape.
        ValueError: if the matrix square root has a non-negligible imaginary diagonal.
    """
    fid_epsilon = 1e-14
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)
    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, 'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, 'Training and test covariances have different dimensions'

    # product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = f'fid calculation produces singular product; adding {fid_epsilon} to diagonal of cov estimates'
        # `warnings` is imported at the top of the file; without that import this
        # branch raised NameError instead of warning and retrying.
        warnings.warn(msg)
        offset = np.eye(sigma1.shape[0]) * fid_epsilon
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError(f'Imaginary component {m}')
        covmean = covmean.real

    mean_diff = mu1 - mu2
    tr_covmean = np.trace(covmean)
    return mean_diff.dot(mean_diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean

def FID(images,gen_model,inception_model=inception_model,myFID_mu2=myFID_mu2, myFID_sigma2=myFID_sigma2):
    """Score a generator against the reference statistics.

    Chains ``gen_model`` into ``inception_model`` as one Keras model, computes the
    activation statistics of ``images`` through it, and returns their Frechet
    distance to ``(myFID_mu2, myFID_sigma2)``. The defaults are bound to the
    module-level objects that exist when this function is defined.
    """
    with strategy.scope():
        image_in = layers.Input(shape=[256, 256, 3], name='input_image')
        features = inception_model(gen_model(image_in))
        fid_model = tf.keras.Model(inputs=image_in, outputs=features)

    gen_mu, gen_sigma = calculate_activation_statistics_mod(images, fid_model)
    return calculate_frechet_distance(gen_mu, gen_sigma, myFID_mu2, myFID_sigma2)

In [None]:
# Number of channels produced by the generator's final layer.
OUTPUT_CHANNELS = 3
def downsample(filters, size, apply_instancenorm=True):
    """Build a stride-2 Conv2D block that halves the spatial resolution.

    The block is Conv2D (no bias) -> optional InstanceNormalization -> LeakyReLU
    with its default slope. There is no pooling or dropout; the strided
    convolution alone shrinks the feature map.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    block_layers = [
        layers.Conv2D(filters, size, strides=2, padding='same',kernel_initializer=initializer, use_bias=False)
    ]
    if apply_instancenorm:
        # Instance normalization is used here in place of batch normalization.
        block_layers.append(tfa.layers.InstanceNormalization(gamma_initializer=gamma_init))
    block_layers.append(layers.LeakyReLU())

    return keras.Sequential(block_layers)

In [None]:
def upsample(filters, size, apply_dropout=False):
    """Build a stride-2 Conv2DTranspose block that doubles the spatial resolution.

    The block is Conv2DTranspose (no bias) -> InstanceNormalization ->
    optional Dropout(0.5) -> ReLU. The returned Sequential is itself a model and
    can be applied to a tensor by calling it.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    block_layers = [
        layers.Conv2DTranspose(filters, size, strides=2,padding='same',kernel_initializer=initializer, use_bias=False),
        tfa.layers.InstanceNormalization(gamma_initializer=gamma_init),
    ]
    if apply_dropout:
        block_layers.append(layers.Dropout(0.5))
    block_layers.append(layers.ReLU())

    return keras.Sequential(block_layers)

In [None]:
def Generator():
    """Build the U-Net style generator: 256x256x3 image in, 256x256x3 tanh image out.

    Eight downsample blocks form the encoder, seven upsample blocks the decoder,
    and each decoder output is concatenated with the matching encoder output
    (skip connection) before a final transposed convolution.
    """
    inputs = layers.Input(shape=[256,256,3])

    # bs = batch size
    down_stack = [
        downsample(64, 4, apply_instancenorm=False), # (bs, 128, 128, 64)
        downsample(128, 4), # (bs, 64, 64, 128)
        downsample(256, 4), # (bs, 32, 32, 256)
        downsample(512, 4), # (bs, 16, 16, 512)
        downsample(512, 4), # (bs, 8, 8, 512)
        downsample(512, 4), # (bs, 4, 4, 512)
        downsample(512, 4), # (bs, 2, 2, 512)
        downsample(512, 4), # (bs, 1, 1, 512)
    ]# Encoder first. More levels could be added by adjusting downsample's stride; VGG is much deeper, but down+up is enough here.

    up_stack = [
        upsample(512, 4, apply_dropout=True), # (bs, 2, 2, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 4, 4, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 8, 8, 1024)
        upsample(512, 4), # (bs, 16, 16, 1024)
        upsample(256, 4), # (bs, 32, 32, 512)
        upsample(128, 4), # (bs, 64, 64, 256)
        upsample(64, 4), # (bs, 128, 128, 128)
    ]# Then the decoder. Shapes listed are after concatenation with the skip tensor.
    # The stacks above only define layers; below they are applied to the input tensor.
    
    initializer = tf.random_normal_initializer(0., 0.02)
    last = layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,strides=2,padding='same',
                                  kernel_initializer=initializer,activation='tanh') # (bs, 256, 256, 3)

    x = inputs

    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)   # run the encoder and remember every intermediate output

    # Drop the bottleneck output (it is the decoder's input) and walk the rest deepest-first.
    skips = reversed(skips[:-1])

    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = layers.Concatenate()([x, skip]) # Skip connection: joining the upsampled tensor with the matching
        # encoder output keeps the encoder's features available while upsampling.

    x = last(x)  # final transposed convolution upsamples to the (256, 256, 3) generated image

    return keras.Model(inputs=inputs, outputs=x) # return the generator model

In [None]:
# Smoke test: build a generator, draw its graph, and run one photo through it.
generator_g = Generator()
tf.keras.utils.plot_model(generator_g, show_shapes=True, dpi=64)

# NOTE(review): the photo is rescaled to [0, 1] before being fed to the generator,
# whereas the training data is in [-1, 1] — harmless for an untrained model, but confirm intent.
photo = (example_photo[0,...] * 0.5 + 0.5)
example_gen_output_y = generator_g(photo[tf.newaxis,...], training=False)

plt.subplot(1,2,1)
plt.imshow(photo, vmin=0, vmax=255) 

plt.subplot(1,2,2)
plt.imshow(example_gen_output_y[0]) 

plt.show()

In [None]:
def Discriminator():
    """Build the patch discriminator: 256x256x3 image in, (30, 30, 1) logit map out.

    Three downsample blocks are followed by two zero-pad + stride-1 convolution
    stages; the final convolution has no activation, so outputs are raw logits.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    inp = layers.Input(shape=[256, 256, 3], name='input_image')

    x = inp

    down1 = downsample(64, 4, False)(x) # (bs, 128, 128, 64)
    down2 = downsample(128, 4)(down1) # (bs, 64, 64, 128)
    down3 = downsample(256, 4)(down2) # (bs, 32, 32, 256)

    zero_pad1 = layers.ZeroPadding2D()(down3) # (bs, 34, 34, 256)
    conv = layers.Conv2D(512, 4, strides=1,
                         kernel_initializer=initializer,
                         use_bias=False)(zero_pad1) # (bs, 31, 31, 512)
    # Few layers overall; the later stages extract features by zero-padding and then convolving.
    norm1 = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(conv)

    leaky_relu = layers.LeakyReLU()(norm1)

    zero_pad2 = layers.ZeroPadding2D()(leaky_relu) # (bs, 33, 33, 512)

    last = layers.Conv2D(1, 4, strides=1,
                         kernel_initializer=initializer)(zero_pad2) # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=inp, outputs=last)

In [None]:
# Smoke test: build a discriminator and draw its graph.
discriminator_y = Discriminator()
tf.keras.utils.plot_model(discriminator_y, show_shapes=True, dpi=64) 

In [None]:
# Create the four networks inside the distribution strategy's scope.
with strategy.scope():
    monet_generator = Generator() # transforms photos to Monet-esque paintings
    photo_generator = Generator() # transforms Monet paintings to be more like photos

    monet_discriminator = Discriminator() # differentiates real Monet paintings and generated Monet paintings
    photo_discriminator = Discriminator() # differentiates real photos and generated photos

In [None]:
# Preview what the untrained networks produce for one photo and one Monet sample.
photo = (example_photo[0,...] * 0.5 + 0.5)
monet = (example_monet[0,...] * 0.5 + 0.5)

# Photo -> fake Monet -> cycled photo, plus the Monet discriminator's response.
example_gen_output_monet_fake = monet_generator(photo[tf.newaxis,...], training=False)
example_gen_output_photo_cycle =  photo_generator(example_gen_output_monet_fake, training=False)
example_disc_out_monet = monet_discriminator(example_gen_output_monet_fake, training=False)

# Monet -> fake photo -> cycled Monet, plus the photo discriminator's response.
example_gen_output_photo_fake =  photo_generator(monet[tf.newaxis,...], training=False)
example_gen_output_monet_cycle = monet_generator(example_gen_output_photo_fake, training=False)
example_disc_out_photo = photo_discriminator(example_gen_output_photo_fake, training=False)

# The networks are not trained yet, so the outputs are not good; generated panels
# are multiplied by `contrast` so that something is visible. The discriminator
# outputs are computed above but not plotted (slots 4 and 8 stay empty).
contrast = 100
panel_specs = [
    (1, photo, {'vmin': 0, 'vmax': 255}),                          # input photo
    (2, example_gen_output_monet_fake[0,...]*contrast, {}),        # fake Monet
    (3, example_gen_output_photo_cycle[0,...]*contrast, {}),       # photo cycle
    (5, monet, {'vmin': 0, 'vmax': 255}),                          # input Monet
    (6, example_gen_output_photo_fake[0,...]*contrast, {}),        # fake photo
    (7, example_gen_output_monet_cycle[0,...]*contrast, {}),       # Monet cycle
]

plt.figure(figsize=(10,10))
for panel_pos, panel_img, panel_kwargs in panel_specs:
    plt.subplot(2, 4, panel_pos)
    plt.imshow(panel_img, **panel_kwargs)

In [None]:
class CycleGan(keras.Model):
    """Keras model that trains two generators and two discriminators together.

    Holds a Monet generator/discriminator pair and a photo generator/discriminator
    pair and implements a custom ``train_step`` that updates all four networks.
    """

    def __init__(
        self,
        monet_generator,
        photo_generator,
        monet_discriminator,
        photo_discriminator,
        lambda_cycle=15,
    ):
        """Store the four sub-models and the weight forwarded to the cycle/identity loss functions."""
        super(CycleGan, self).__init__()
        self.m_gen = monet_generator
        self.p_gen = photo_generator
        self.m_disc = monet_discriminator
        self.p_disc = photo_discriminator
        self.lambda_cycle = lambda_cycle
        
    def compile(
        self,
        m_gen_optimizer,
        p_gen_optimizer,
        m_disc_optimizer,
        p_disc_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
        identity_loss_fn
    ):
        """Attach one optimizer per network and the four loss callables.

        ``gen_loss_fn`` is called with the discriminator output for fake images,
        ``disc_loss_fn`` with ``(real_output, fake_output)``, and both
        ``cycle_loss_fn`` and ``identity_loss_fn`` with
        ``(real_image, other_image, self.lambda_cycle)``.
        """
        super(CycleGan, self).compile()
        self.m_gen_optimizer = m_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.m_disc_optimizer = m_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn
        
    def train_step(self, batch_data):
        """Run one optimization step on a ``(real_monet, real_photo)`` batch pair.

        Returns a dict with the total Monet/photo generator losses and the
        Monet/photo discriminator losses.
        """
        real_monet, real_photo = batch_data
        
        # persistent=True because four separate gradient() calls are made on this tape.
        with tf.GradientTape(persistent=True) as tape:
            # photo to monet back to photo
            fake_monet = self.m_gen(real_photo, training=True)
            cycled_photo = self.p_gen(fake_monet, training=True)

            # monet to photo back to monet
            fake_photo = self.p_gen(real_monet, training=True)
            cycled_monet = self.m_gen(fake_photo, training=True)

            # generating itself
            same_monet = self.m_gen(real_monet, training=True)
            same_photo = self.p_gen(real_photo, training=True)

            # discriminator used to check, inputing real images
            disc_real_monet = self.m_disc(real_monet, training=True)
            disc_real_photo = self.p_disc(real_photo, training=True)

            # discriminator used to check, inputing fake images
            disc_fake_monet = self.m_disc(fake_monet, training=True)
            disc_fake_photo = self.p_disc(fake_photo, training=True)

            # evaluates generator loss
            monet_gen_loss = self.gen_loss_fn(disc_fake_monet)
            photo_gen_loss = self.gen_loss_fn(disc_fake_photo)

            # evaluates total cycle consistency loss
            total_cycle_loss = self.cycle_loss_fn(real_monet, cycled_monet, self.lambda_cycle) + self.cycle_loss_fn(real_photo, cycled_photo, self.lambda_cycle)

            # evaluates total generator loss
            # (the same total cycle loss is added to both generators; the identity
            # loss is also handed self.lambda_cycle as its weight argument)
            total_monet_gen_loss = monet_gen_loss + total_cycle_loss + self.identity_loss_fn(real_monet, same_monet, self.lambda_cycle)
            total_photo_gen_loss = photo_gen_loss + total_cycle_loss + self.identity_loss_fn(real_photo, same_photo, self.lambda_cycle)

            # evaluates discriminator loss
            monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet)
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo)

        # Calculate the gradients for generator and discriminator
        monet_generator_gradients = tape.gradient(total_monet_gen_loss,
                                                  self.m_gen.trainable_variables)
        photo_generator_gradients = tape.gradient(total_photo_gen_loss,
                                                  self.p_gen.trainable_variables)

        monet_discriminator_gradients = tape.gradient(monet_disc_loss,
                                                      self.m_disc.trainable_variables)
        photo_discriminator_gradients = tape.gradient(photo_disc_loss,
                                                      self.p_disc.trainable_variables)

        # Apply the gradients to the optimizer
        self.m_gen_optimizer.apply_gradients(zip(monet_generator_gradients,
                                                 self.m_gen.trainable_variables))

        self.p_gen_optimizer.apply_gradients(zip(photo_generator_gradients,
                                                 self.p_gen.trainable_variables))

        self.m_disc_optimizer.apply_gradients(zip(monet_discriminator_gradients,
                                                  self.m_disc.trainable_variables))

        self.p_disc_optimizer.apply_gradients(zip(photo_discriminator_gradients,
                                                  self.p_disc.trainable_variables))
        
        return {
            "monet_gen_loss": total_monet_gen_loss,
            "photo_gen_loss": total_photo_gen_loss,
            "monet_disc_loss": monet_disc_loss,
            "photo_disc_loss": photo_disc_loss
        }

In [None]:
with strategy.scope():  # defined under the distribution strategy
    def discriminator_loss(real, generated):
        """Discriminator loss: un-reduced BCE-with-logits on real and generated outputs.

        Real outputs are scored against all-ones targets and generated outputs
        against all-zeros targets; the sum is scaled by 0.4. Cross-entropy is
        used here; MSE is an alternative worth trying.
        """
        bce = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
        loss_on_real = bce(tf.ones_like(real), real)
        loss_on_generated = bce(tf.zeros_like(generated), generated)

        return (loss_on_real + loss_on_generated) * 0.4

In [None]:
with strategy.scope():
    def generator_loss(generated):
        """Generator loss: un-reduced BCE-with-logits of the discriminator output against all-ones targets."""
        bce = tf.keras.losses.BinaryCrossentropy(from_logits=True,
                                                 reduction=tf.keras.losses.Reduction.NONE)
        wanted_verdict = tf.ones_like(generated)
        return bce(wanted_verdict, generated)
with strategy.scope():
    def calc_cycle_loss(real_image, cycled_image, LAMBDA):
        """Cycle-consistency loss: ``LAMBDA`` times the mean absolute difference.

        Args:
            real_image: the original image batch.
            cycled_image: the batch after a round trip through both generators.
            LAMBDA: weight applied to the mean absolute error.
        """
        loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))

        # Use the weight the caller passed in. The previous body ignored LAMBDA and
        # read the global Lambda_cycle, so CycleGan's lambda_cycle had no effect.
        return LAMBDA * loss1

with strategy.scope():
    def identity_loss(real_image, same_image, LAMBDA):
        """Identity loss: the global ``Lambda_identity`` times the mean absolute difference.

        NOTE(review): the ``LAMBDA`` argument is accepted but not used; the weight
        comes from the module-level ``Lambda_identity`` — confirm this is intended.
        """
        mean_abs_diff = tf.reduce_mean(tf.abs(real_image - same_image))
        return Lambda_identity * mean_abs_diff

In [None]:
# One Adam optimizer per network, all with learning rate 0.002 and beta_1 = 0.5.
# NOTE(review): only the first optimizer passes decay=0 explicitly — presumably
# that matches the default used by the other three; confirm.
with strategy.scope():
    monet_generator_optimizer = tf.keras.optimizers.Adam(0.002, decay = 0, beta_1=0.5)
    photo_generator_optimizer = tf.keras.optimizers.Adam(0.002, beta_1=0.5)

    monet_discriminator_optimizer = tf.keras.optimizers.Adam(0.002, beta_1=0.5)
    photo_discriminator_optimizer = tf.keras.optimizers.Adam(0.002, beta_1=0.5)

In [None]:
# Assemble the CycleGan wrapper (default lambda_cycle) and attach optimizers and losses.
with strategy.scope():
    cycle_gan_model = CycleGan(
        monet_generator, photo_generator, monet_discriminator, photo_discriminator
    )

    cycle_gan_model.compile(
        m_gen_optimizer = monet_generator_optimizer,
        p_gen_optimizer = photo_generator_optimizer,
        m_disc_optimizer = monet_discriminator_optimizer,
        p_disc_optimizer = photo_discriminator_optimizer,
        gen_loss_fn = generator_loss,
        disc_loss_fn = discriminator_loss,
        cycle_loss_fn = calc_cycle_loss,
        identity_loss_fn = identity_loss
    )
    
# Disabled single-call training; the epoch loop in the next cell is used instead.
#cycle_gan_model.fit(
    #full_dataset,
    #epochs=EPOCHS,
    #steps_per_epoch=(max(n_monet_samples, n_photo_samples)//BATCH_SIZE),
#)# this is where training would run

In [None]:
%%time
# Train in rounds of 1500 steps; after each round score the Monet generator with
# FID() and save it whenever the score is the lowest seen so far.
fids=[]
best_fid=999999999
for epoch in range(1,32):

    print("Epoch = ",epoch)
    hist=cycle_gan_model.fit(full_dataset,steps_per_epoch=1500, epochs=1).history
    cur_fid=FID(fast_photo_ds,monet_generator)
#     disc_m_loss1.append(hist["monet_disc_loss1"][0][0][0])
#     disc_p_loss1.append(hist["photo_disc_loss1"][0][0][0])
    fids.append(cur_fid)
    print("After epoch #{} FID = {}\n".format(epoch,cur_fid))
    if cur_fid<best_fid:
            # keep only the best-scoring generator on disk
            best_fid=cur_fid
            monet_generator.save('monet_generator.h5')


# hist=cycle_gan_model.fit(gan_ds,steps_per_epoch=30, epochs=EPOCHS).history
# hist=cycle_gan_model.fit(gan_ds,steps_per_epoch=30,validation_data=([1]), epochs=3).history

In [None]:
# Plot the FID value recorded after each training round.
# plt.plot(disc_m_loss1, label='monet_disc_loss1')
# plt.plot(disc_p_loss1, label='photo_disc_loss1')
plt.plot(np.array(fids), label='FID')

plt.legend()
plt.show()

In [None]:
# Replace the in-memory Monet generator with the checkpoint saved by the training loop.
# cycle_gan_model.built = True
# cycle_gan_model.load_weights('monet.h5')
with strategy.scope():
    monet_generator = tf.keras.models.load_model('monet_generator.h5')

In [None]:
# Five rows: input photo, its Monet-style translation, and the photo cycled back.
_, ax = plt.subplots(5, 3, figsize=(32, 32))
for i, img in enumerate(my_photo_ds.take(5)):
    prediction = monet_generator(img, training=False)
    cycledphoto = photo_generator(prediction, training=False)

    # De-normalize from [-1, 1] to uint8 and keep only the first image of each batch.
    img, prediction, cycledphoto = (
        (batch * 127.5 + 127.5)[0].numpy().astype(np.uint8)
        for batch in (img, prediction, cycledphoto)
    )

    row_titles = ("Input Photo", "Monet-esque", "Cycled Photo")
    for axis, picture, title in zip(ax[i], (img, prediction, cycledphoto), row_titles):
        axis.imshow(picture)
        axis.set_title(title)
        axis.axis("off")

plt.show()

In [None]:
# Five rows: input Monet, its photo-style translation, and the Monet cycled back.
_, ax = plt.subplots(5, 3, figsize=(32, 32))
for i, img in enumerate(my_monet_ds.take(5)):
    prediction = photo_generator(img, training=False)
    cycledphoto = monet_generator(prediction, training=False)

    # De-normalize from [-1, 1] to uint8 and keep only the first image of each batch.
    img, prediction, cycledphoto = (
        (batch * 127.5 + 127.5)[0].numpy().astype(np.uint8)
        for batch in (img, prediction, cycledphoto)
    )

    row_titles = ("Input Monet", "Generated Photo", "Cycled Monet")
    for axis, picture, title in zip(ax[i], (img, prediction, cycledphoto), row_titles):
        axis.imshow(picture)
        axis.set_title(title)
        axis.axis("off")

plt.show()

In [None]:
# Show eight photo batches: first input image next to its generated Monet.
ds_iter = iter(my_photo_ds)
for n_sample in range(8):
    example_sample = next(ds_iter)
    generated_sample = monet_generator(example_sample)

    f = plt.figure(figsize=(32, 32))
    side_by_side = (
        (121, 'Input image', example_sample),
        (122, 'Generated image', generated_sample),
    )
    for subplot_code, panel_title, batch in side_by_side:
        plt.subplot(subplot_code)
        plt.title(panel_title)
        plt.imshow(batch[0] * 0.5 + 0.5)  # [-1, 1] -> [0, 1] for display
        plt.axis('off')
    plt.show()

In [None]:
# Show eight Monet batches: first input image next to its generated photo.
ds_iter = iter(my_monet_ds)
for n_sample in range(8):
    example_sample = next(ds_iter)
    generated_sample = photo_generator(example_sample)

    f = plt.figure(figsize=(24, 24))
    side_by_side = (
        (121, 'Input image', example_sample),
        (122, 'Generated image', generated_sample),
    )
    for subplot_code, panel_title, batch in side_by_side:
        plt.subplot(subplot_code)
        plt.title(panel_title)
        plt.imshow(batch[0] * 0.5 + 0.5)  # [-1, 1] -> [0, 1] for display
        plt.axis('off')
    plt.show()

In [None]:
import PIL
def predict_and_save(input_ds, generator_model, output_path):
    """Run the generator over a dataset and save each result as a numbered JPEG.

    Args:
        input_ds: iterable of image batches; only the first image of each
            batch's prediction is saved.
        generator_model: callable model invoked as ``model(batch, training=False)``.
        output_path: prefix for the output files; files are written as
            ``{output_path}1.jpg``, ``{output_path}2.jpg``, ...
    """
    # `import PIL` alone does not load the PIL.Image submodule; import it explicitly
    # so this function does not depend on another library having loaded it first.
    import PIL.Image

    for i, img in enumerate(input_ds, start=1):
        prediction = generator_model(img, training=False)[0].numpy() # make prediction
        prediction = (prediction * 127.5 + 127.5).astype(np.uint8)   # re-scale [-1, 1] -> [0, 255]
        PIL.Image.fromarray(prediction).save(f'{output_path}{str(i)}.jpg')

# Create the output folder; exist_ok=True lets this cell be re-run without
# failing because the folder already exists.
os.makedirs('../images/', exist_ok=True)

# Translate every photo (one per batch) and write the results to the folder.
predict_and_save(load_dataset(PHOTO_FILENAMES).batch(1), monet_generator, '../images/')

In [None]:
# base_name must not end with a slash: make_archive appends ".zip" to it, so a
# trailing slash can yield "/kaggle/working/images/.zip" instead of "images.zip".
shutil.make_archive('/kaggle/working/images', 'zip', '../images')

print(f"Generated samples: {len([name for name in os.listdir('../images/') if os.path.isfile(os.path.join('../images/', name))])}")