In [12]:
import tensorflow as tf

tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
strategy = tf.distribute.experimental.TPUStrategy(tpu)

In [13]:
from kaggle_datasets import KaggleDatasets

gcs_path = KaggleDatasets().get_gcs_path()
gcs_path

'gs://kds-028d4efa9379430e44a4e41606cd6bf572ebc886adcaa2fb4e31b394'

In [14]:
AUTOTUNE = tf.data.experimental.AUTOTUNE
IMAGE_SIZE = [256, 256]

monet_filenames = tf.io.gfile.glob(gcs_path + '/monet_tfrec/*.tfrec')
photo_filenames = tf.io.gfile.glob(gcs_path + '/photo_tfrec/*.tfrec')



In [15]:
features = {'image': tf.io.FixedLenFeature([], tf.string)}

def read_tfrecord(example):
    image_data = tf.io.parse_single_example(example, features)
    image = image_data['image']
    image = tf.image.decode_jpeg(image, channels=3)
    image = (tf.cast(image, tf.float32) / 127.5) - 1
    image = tf.reshape(image, [*IMAGE_SIZE, 3])
    return image


monet_dataset = tf.data.TFRecordDataset(monet_filenames).map(read_tfrecord, num_parallel_calls=AUTOTUNE).batch(1)
photo_dataset = tf.data.TFRecordDataset(photo_filenames).map(read_tfrecord, num_parallel_calls=AUTOTUNE).batch(1)

In [27]:
monet_dataset_len = len(list(iter(monet_dataset)))
photo_dataset_len = len(list(photo_dataset.as_numpy_iterator()))
print(monet_dataset_len, photo_dataset_len)

300 7038


In [18]:
example = next(iter(monet_dataset))

In [19]:
example.shape

TensorShape([1, 256, 256, 3])

In [20]:
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers

OUTPUT_CHANNELS = 3


class CycleGan(keras.Model):
    def __init__(self):
        super(CycleGan, self).__init__()
        self.monet_generator = self.create_generator()
        self.photo_generator = self.create_generator()
        self.monet_discriminator = self.create_discriminator()
        self.photo_discriminator = self.create_discriminator()
        self.lambda_cycle = 10

    def compile(self):
        super(CycleGan, self).compile()
        self.monet_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
        self.monet_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
        self.photo_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
        self.photo_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

    def create_downsampler(self, filters, size, apply_instance_norm=True):
        model = keras.Sequential()
        model.add(
            layers.Conv2D(
                filters,
                size,
                strides=2,
                padding="same",
                use_bias=False,
                kernel_initializer=tf.random_normal_initializer(0.0, 0.02),
            )
        )
        if apply_instance_norm:
            model.add(
                tfa.layers.InstanceNormalization(
                    gamma_initializer=keras.initializers.RandomNormal(
                        mean=0.0, stddev=0.02
                    )
                )
            )
        model.add(layers.LeakyReLU())
        return model

    def create_upsampler(self, filters, size, apply_dropout=False):
        model = keras.Sequential()
        model.add(
            layers.Conv2DTranspose(
                filters,
                size,
                strides=2,
                padding="same",
                use_bias=False,
                kernel_initializer=tf.random_normal_initializer(0.0, 0.02),
            )
        )
        model.add(
            tfa.layers.InstanceNormalization(
                gamma_initializer=tf.random_normal_initializer(0.0, 0.02)
            )
        )
        if apply_dropout:
            model.add(layers.Dropout(0.5))
        model.add(layers.ReLU())
        return model

    def create_generator(self):
        downsampler_stack = [
            self.create_downsampler(64, 4, apply_instance_norm=False),
            self.create_downsampler(128, 4),
            self.create_downsampler(256, 4),
        ] + [self.create_downsampler(512, 4) for i in range(5)]
        upsampler_stack = [
            self.create_upsampler(512, 4, apply_dropout=True) for i in range(3)
        ] + [
            self.create_upsampler(512, 4),
            self.create_upsampler(256, 4),
            self.create_upsampler(128, 4),
            self.create_upsampler(64, 4),
        ]
        input_layer = layers.Input(shape=[256, 256, 3])
        x = input_layer
        skips = []
        for downsampler in downsampler_stack:
            x = downsampler(x)
            skips.append(x)
        skips = reversed(skips[:-1])
        for upsampler, skip_layer in zip(upsampler_stack, skips):
            x = upsampler(x)
            x = layers.Concatenate()([x, skip_layer])
        last_layer = layers.Conv2DTranspose(
            OUTPUT_CHANNELS,
            4,
            strides=2,
            padding="same",
            kernel_initializer=tf.random_normal_initializer(0.0, 0.02),
            activation="tanh",
        )
        x = last_layer(x)
        return keras.Model(inputs=input_layer, outputs=x)

    def create_discriminator(self):
        input_layer = layers.Input(shape=[256, 256, 3], name="input_image")
        x = input_layer
        downsampler1 = self.create_downsampler(64, 4, False)(x)
        downsampler2 = self.create_downsampler(128, 4)(downsampler1)
        downsampler3 = self.create_downsampler(256, 4)(downsampler2)
        zero_pad1 = layers.ZeroPadding2D()(downsampler3)
        conv_layer = layers.Conv2D(
            512,
            4,
            strides=1,
            use_bias=False,
            kernel_initializer=tf.random_normal_initializer(0.0, 0.02),
        )(zero_pad1)
        normalization_layer1 = tfa.layers.InstanceNormalization(
            gamma_initializer=tf.random_normal_initializer(0.0, 0.02)
        )(conv_layer)
        leaky_relu_layer = layers.LeakyReLU()(normalization_layer1)
        zero_pad2 = layers.ZeroPadding2D()(leaky_relu_layer)
        last_layer = layers.Conv2D(
            1, 4, strides=1, kernel_initializer=tf.random_normal_initializer(0.0, 0.02)
        )(zero_pad2)
        return tf.keras.Model(inputs=input_layer, outputs=last_layer)

    def discriminator_loss_fn(self, real, fake):
        real_loss = tf.keras.losses.BinaryCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )(tf.ones_like(real), real)
        fake_loss = tf.keras.losses.BinaryCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )(tf.zeros_like(fake), fake)
        return (real_loss + fake_loss) / 2

    def generator_loss_fn(self, generated_image):
        return tf.keras.losses.BinaryCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )(tf.ones_like(generated_image), generated_image)

    def cycle_loss_fn(self, image, cycled_image, lambda_cycle):
        return tf.reduce_mean(tf.abs(image - cycled_image)) * lambda_cycle

    def identity_loss_fn(self, real_photo, photo, lambda_cycle):
        return tf.reduce_mean(tf.abs(real_photo - photo)) * lambda_cycle / 2

    def train_step(self, batch_data):
        real_monet, real_photo = batch_data

        with tf.GradientTape(persistent=True) as tape:
            fake_monet = self.monet_generator(real_photo, training=True)
            cycled_photo = self.photo_generator(fake_monet, training=True)
            fake_photo = self.photo_generator(real_monet, training=True)
            cycled_monet = self.monet_generator(fake_photo, training=True)

            monet1 = self.monet_generator(real_monet, training=True)
            photo1 = self.photo_generator(real_photo, training=True)

            monet_real_discriminated = self.monet_discriminator(
                real_monet, training=True
            )
            monet_fake_discriminated = self.monet_discriminator(
                fake_monet, training=True
            )
            photo_real_discriminated = self.photo_discriminator(
                real_photo, training=True
            )
            photo_fake_discriminated = self.photo_discriminator(
                fake_photo, training=True
            )

            monet_generator_loss = self.generator_loss_fn(monet_fake_discriminated)
            photo_generator_loss = self.generator_loss_fn(photo_fake_discriminated)
            cycle_loss = self.cycle_loss_fn(
                real_monet, cycled_monet, self.lambda_cycle
            ) + self.cycle_loss_fn(real_photo, cycled_photo, self.lambda_cycle)
            total_monet_generator_loss = (
                monet_generator_loss
                + cycle_loss
                + self.identity_loss_fn(real_monet, monet1, self.lambda_cycle)
            )
            total_photo_generator_loss = (
                photo_generator_loss
                + cycle_loss
                + self.identity_loss_fn(real_photo, photo1, self.lambda_cycle)
            )
            monet_discriminator_loss = self.discriminator_loss_fn(
                monet_real_discriminated, monet_fake_discriminated
            )
            photo_discriminator_loss = self.discriminator_loss_fn(
                photo_real_discriminated, photo_fake_discriminated
            )
        monet_generator_gradients = tape.gradient(
            total_monet_generator_loss, self.monet_generator.trainable_variables
        )
        photo_generator_gradients = tape.gradient(
            total_photo_generator_loss, self.photo_generator.trainable_variables
        )
        monet_discriminator_gradients = tape.gradient(
            monet_discriminator_loss, self.monet_discriminator.trainable_variables
        )
        photo_discriminator_gradients = tape.gradient(
            photo_discriminator_loss, self.photo_discriminator.trainable_variables
        )
        self.monet_generator_optimizer.apply_gradients(
            zip(monet_generator_gradients, self.monet_generator.trainable_variables)
        )
        self.photo_generator_optimizer.apply_gradients(
            zip(photo_generator_gradients, self.photo_generator.trainable_variables)
        )
        self.monet_discriminator_optimizer.apply_gradients(
            zip(
                monet_discriminator_gradients,
                self.monet_discriminator.trainable_variables,
            )
        )
        self.photo_discriminator_optimizer.apply_gradients(
            zip(
                photo_discriminator_gradients,
                self.photo_discriminator.trainable_variables,
            )
        )
        return {
            "monet_generator_loss": total_monet_generator_loss,
            "photo_generator_loss": total_photo_generator_loss,
            "monet_discriminator_loss": monet_discriminator_loss,
            "photo_discriminator_loss": photo_discriminator_loss,
        }

In [21]:
with strategy.scope():
    cycle_gan_model = CycleGan()
    cycle_gan_model.compile()

In [29]:
cycle_gan_model.fit(
    tf.data.Dataset.zip((monet_dataset.repeat(photo_dataset_len//monet_dataset_len), photo_dataset)),
    steps_per_epoch=300,
    epochs=30
)

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30


OutOfRangeError: 9 root error(s) found.
  (0) Out of range: {{function_node __inference_train_function_156099}} End of sequence
	 [[{{node cond_9/else/_84/cond_9/IteratorGetNext}}]]
  (1) Cancelled: {{function_node __inference_train_function_156099}} Operation was cancelled
	 [[{{node cond_15/else/_150/cond_15/IteratorGetNext}}]]
  (2) Out of range: {{function_node __inference_train_function_156099}} End of sequence
	 [[{{node cond_10/else/_95/cond_10/IteratorGetNext}}]]
	 [[tpu_compile_succeeded_assert/_17447968332845859611/_5/_221]]
  (3) Out of range: {{function_node __inference_train_function_156099}} End of sequence
	 [[{{node cond_10/else/_95/cond_10/IteratorGetNext}}]]
	 [[Pad_12/paddings/_140]]
  (4) Out of range: {{function_node __inference_train_function_156099}} End of sequence
	 [[{{node cond_10/else/_95/cond_10/IteratorGetNext}}]]
	 [[tpu_compile_succeeded_assert/_17447968332845859611/_5/_233]]
  (5) Out of range: {{function_node __inference_train_function_156099}} End of sequence
	 [[{{node cond_10/else/_95/cond_10/IteratorGetNext}}]]
	 [[TPUReplicate/_compile/_6444032947155159481/_4/_170]]
  (6) Out of range: {{function_node __inference_train_function_156099}} End of sequence
	 [[{{node cond_10/else/_95/cond_10/IteratorGetNext}}]]
	 [[cluster_train_function/_execute_7_0/_289]]
  (7) Out of range: {{function_node __inference_train_function_156099}} End of sequence
	 [[{{node cond_10/else/_95/cond_10/IteratorGetNext}}]]
	 [[strided_slice_21/_204]]
  (8) Out of range: {{function_node __inference_train_function_156099}} End of sequence
	 [[{{node cond_10/else/_95/cond_10/IteratorGetNext}}]]
0 successful operations.
0 derived errors ignored.