In [None]:
# Ask Kaggle for the storage path of the attached dataset; the TFRecord glob
# patterns in the next cell are built by appending to this path.
gcs_path = KaggleDatasets().get_gcs_path()

In [None]:
# Expand the Monet glob pattern into a concrete list of TFRecord files and
# store it on Config so later cells can reuse it.
Config.monet_filenames = tf.io.gfile.glob(str(gcs_path + Config.monet_path))
print('Monet TFRecord Files:', len(Config.monet_filenames))

# Same for the photo TFRecord files.
Config.photo_filenames = tf.io.gfile.glob(str(gcs_path + Config.photo_path))
print('Photo TFRecord Files:', len(Config.photo_filenames))

In [None]:
def decode_image(image):
    """Decode JPEG bytes into a float32 RGB tensor scaled to [-1, 1].

    The decoded 8-bit pixels are divided by 127.5 and shifted by -1, then
    reshaped to ``[*Config.image_size, 3]`` so the result has a static shape.
    """
    pixels = tf.image.decode_jpeg(image, channels=3)
    scaled = (tf.cast(pixels, tf.float32) / 127.5) - 1
    return tf.reshape(scaled, [*Config.image_size, 3])

def read_tfrecord(example):
    """Parse one serialized example and return its decoded image tensor.

    Three scalar string features are declared for parsing, but only the
    ``image`` feature is decoded and returned.
    """
    feature_spec = {
        feature_name: tf.io.FixedLenFeature([], tf.string)
        for feature_name in ("image_name", "image", "target")
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image'])

In [None]:
def load_dataset(filenames, labeled=True, ordered=False):
    """Build a dataset of decoded images from the given TFRecord files.

    Each record is mapped through ``read_tfrecord`` with ``AUTOTUNE``
    parallelism. ``labeled`` and ``ordered`` are accepted but are not
    consulted by this implementation.
    """
    records = tf.data.TFRecordDataset(filenames)
    return records.map(read_tfrecord, num_parallel_calls=AUTOTUNE)

In [None]:
# Batch size 1: every element is a single image with a leading batch axis.
# `labeled=True` is passed through, but load_dataset does not use it.
monet_ds = load_dataset(Config.monet_filenames, labeled=True).batch(1)
photo_ds = load_dataset(Config.photo_filenames, labeled=True).batch(1)

In [None]:
# Pull the first batch from each dataset for a quick visual sanity check.
example_monet = next(iter(monet_ds))
example_photo = next(iter(photo_ds))

In [None]:
# Plot 1 image from each set, side by side
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

for ax, (batch, label) in zip(axes, [(example_monet, "Monet"), (example_photo, "Photo")]):
    # Pixels are stored in [-1, 1]; map them back to [0, 1] for imshow
    ax.imshow(batch[0].numpy() * 0.5 + 0.5)
    ax.set_title(label)
    ax.axis("off")

fig.tight_layout()
plt.show()

In [None]:
def downsample(filters, size, apply_batchnorm=True):
    """Return a Sequential block that halves the spatial resolution.

    The block is a stride-2, bias-free Conv2D (weights drawn from
    N(0, 0.02)), optionally followed by BatchNormalization, and always
    ending in LeakyReLU.
    """
    layers = [
        Conv2D(filters, size, strides=2, padding='same',
               kernel_initializer=tf.random_normal_initializer(0., 0.02),
               use_bias=False),
    ]
    if apply_batchnorm:
        layers.append(BatchNormalization())
    layers.append(LeakyReLU())
    return Sequential(layers)

def upsample(filters, size, apply_dropout=False):
    """Return a Sequential block that doubles the spatial resolution.

    The block is a stride-2, bias-free Conv2DTranspose (weights drawn from
    N(0, 0.02)) followed by BatchNormalization, an optional Dropout(0.5),
    and a closing ReLU.
    """
    layers = [
        Conv2DTranspose(filters, size, strides=2, padding='same',
                        kernel_initializer=tf.random_normal_initializer(0., 0.02),
                        use_bias=False),
        BatchNormalization(),
    ]
    if apply_dropout:
        layers.append(Dropout(0.5))
    layers.append(ReLU())
    return Sequential(layers)

In [None]:
def Generator():
    """Build the encoder-decoder generator with skip connections.

    Three downsampling stages are followed by two upsampling stages, each
    concatenated with the encoder output of matching resolution, and a final
    stride-2 transposed convolution with tanh activation that emits 3 channels.

    The shape comments assume a 256x256 input.
    """
    inputs = Input(shape=[*Config.image_size, 3])
    encoder = [
        downsample(64, 4, apply_batchnorm=False),  # (bs, 128, 128, 64)
        downsample(128, 4),  # (bs, 64, 64, 128)
        downsample(256, 4),  # (bs, 32, 32, 256)
    ]
    decoder = [
        upsample(128, 4),  # (bs, 64, 64, 128)
        upsample(64, 4),   # (bs, 128, 128, 64)
    ]
    output_layer = Conv2DTranspose(3, 4, strides=2, padding='same',
                                   kernel_initializer=tf.random_normal_initializer(0., 0.02),
                                   activation='tanh')  # (bs, 256, 256, 3)

    # Encoder pass, remembering every stage's output
    features = inputs
    encoder_outputs = []
    for stage in encoder:
        features = stage(features)
        encoder_outputs.append(features)

    # The deepest output feeds the decoder directly; the remaining outputs
    # become skip connections, visited from deepest to shallowest.
    skip_inputs = encoder_outputs[-2::-1]
    for stage, skip in zip(decoder, skip_inputs):
        features = stage(features)
        features = Concatenate()([features, skip])

    return Model(inputs=inputs, outputs=output_layer(features))

In [None]:
def Discriminator():
    """Build the convolutional discriminator.

    Three downsampling stages are followed by two zero-padded, stride-1
    convolutions. The last convolution has a single filter and no activation,
    so the model outputs a map of raw scores rather than probabilities.
    """
    init = tf.random_normal_initializer(0., 0.02)
    image = Input(shape=[*Config.image_size, 3], name='input_image')

    features = image
    for stage in (downsample(64, 4, False), downsample(128, 4), downsample(256, 4)):
        features = stage(features)

    features = ZeroPadding2D()(features)
    features = Conv2D(512, 4, strides=1, kernel_initializer=init, use_bias=False)(features)
    features = BatchNormalization()(features)
    features = LeakyReLU()(features)

    features = ZeroPadding2D()(features)
    scores = Conv2D(1, 4, strides=1, kernel_initializer=init)(features)
    return Model(inputs=inp_or_image(image), outputs=scores) if False else Model(inputs=image, outputs=scores)

In [None]:
# Domain naming follows how train_step is called: x = photos, y = Monet paintings.
generator_g = Generator()  # Photo → Monet (x → y)
generator_f = Generator()  # Monet → Photo (y → x)

discriminator_x = Discriminator()  # Discriminator for Photo (x)
discriminator_y = Discriminator()  # Discriminator for Monet (y)

In [None]:
# Print layer-by-layer summaries of one generator and one discriminator;
# the other pair is built by the same constructors.
generator_f.summary()
discriminator_x.summary()

In [None]:
# The discriminators emit raw scores, so the loss applies the sigmoid itself.
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real, generated):
    """Return the discriminator loss for one real/generated pair of score maps.

    Real scores are compared against ones and generated scores against zeros;
    the two cross-entropy terms are summed and halved.
    """
    loss_on_real = loss_obj(tf.ones_like(real), real)
    loss_on_generated = loss_obj(tf.zeros_like(generated), generated)
    return (loss_on_real + loss_on_generated) * 0.5

def generator_loss(generated):
    """Return the adversarial loss: cross-entropy of the generated scores against ones."""
    wanted = tf.ones_like(generated)
    return loss_obj(wanted, generated)

def cycle_loss(real_image, cycled_image):
    """Return the mean absolute difference between an image and its round trip, scaled by 10."""
    mean_abs_error = tf.reduce_mean(tf.abs(real_image - cycled_image))
    return mean_abs_error * 10.0

def identity_loss(real_image, same_image):
    """Return the mean absolute difference between an image and its identity mapping, scaled by 5."""
    mean_abs_error = tf.reduce_mean(tf.abs(real_image - same_image))
    return mean_abs_error * 5.0

In [None]:
@tf.function
def train_step(real_x, real_y):
    """Run one optimisation step for both generators and both discriminators.

    Args:
        real_x: batch from domain X (the training loop passes photos).
        real_y: batch from domain Y (the training loop passes Monet paintings).

    Returns:
        None. The four optimizers update the model variables in place.
    """
    # The tape is persistent because four separate gradients are taken from it.
    with tf.GradientTape(persistent=True) as tape:
        # X -> Y -> X round trip
        fake_y = generator_g(real_x, training=True)
        cycled_x = generator_f(fake_y, training=True)

        # Y -> X -> Y round trip
        fake_x = generator_f(real_y, training=True)
        cycled_y = generator_g(fake_x, training=True)

        # Identity mappings: each generator is fed an image of its own target domain
        same_x = generator_f(real_x, training=True)
        same_y = generator_g(real_y, training=True)

        disc_real_x = discriminator_x(real_x, training=True)
        disc_real_y = discriminator_y(real_y, training=True)

        disc_fake_x = discriminator_x(fake_x, training=True)
        disc_fake_y = discriminator_y(fake_y, training=True)

        gen_g_loss = generator_loss(disc_fake_y)
        gen_f_loss = generator_loss(disc_fake_x)

        total_cycle_loss = cycle_loss(real_x, cycled_x) + cycle_loss(real_y, cycled_y)

        # Weight the identity term identically for both generators so the two
        # translation directions are regularised symmetrically.
        total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y) * Config.identity_weight
        total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x) * Config.identity_weight

        disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
        disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

    # Compute gradients of each loss with respect to its own model only
    generator_g_grad = tape.gradient(total_gen_g_loss, generator_g.trainable_variables)
    generator_f_grad = tape.gradient(total_gen_f_loss, generator_f.trainable_variables)

    discriminator_x_grad = tape.gradient(disc_x_loss, discriminator_x.trainable_variables)
    discriminator_y_grad = tape.gradient(disc_y_loss, discriminator_y.trainable_variables)

    # Apply gradients
    generator_g_optimizer.apply_gradients(zip(generator_g_grad, generator_g.trainable_variables))
    generator_f_optimizer.apply_gradients(zip(generator_f_grad, generator_f.trainable_variables))
    discriminator_x_optimizer.apply_gradients(zip(discriminator_x_grad, discriminator_x.trainable_variables))
    discriminator_y_optimizer.apply_gradients(zip(discriminator_y_grad, discriminator_y.trainable_variables))

In [None]:
def visualize_epoch_result(epoch):
    """On every fifth epoch, plot a photo next to generator_g's output and save the figure.

    ``epoch`` is zero-based; nothing happens unless ``epoch + 1`` is a multiple
    of 5. The figure is written to ``cycle_outputs/`` (created on demand) and
    then shown.
    """
    if (epoch + 1) % 5 != 0:
        return

    os.makedirs("cycle_outputs", exist_ok=True)

    sample_photo = next(iter(photo_ds))
    prediction = generator_g(sample_photo)

    panels = [(sample_photo, "Input"), (prediction, f"Epoch {epoch+1}")]

    plt.figure(figsize=(8, 4))
    for position, (batch, caption) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        # Map pixels from [-1, 1] back to [0, 1] for display
        plt.imshow((batch[0] * 0.5 + 0.5).numpy())
        plt.title(caption)
        plt.axis('off')

    plt.tight_layout()
    plt.savefig(f"cycle_outputs/epoch_{epoch+1}.png")
    plt.show()

In [None]:
def recalc_losses():
    """Evaluate every loss on one photo/Monet batch and append the values to the history lists.

    The first batch of ``photo_ds`` and of ``monet_ds`` is run through all four
    models in inference mode; the resulting scalars are appended to the
    module-level ``*_losses`` lists.
    """
    photo = next(iter(photo_ds))
    monet = next(iter(monet_ds))

    # Translations and their round trips
    fake_monet = generator_g(photo, training=False)
    cycled_photo = generator_f(fake_monet, training=False)
    fake_photo = generator_f(monet, training=False)
    cycled_monet = generator_g(fake_photo, training=False)

    # Identity mappings
    same_photo = generator_f(photo, training=False)
    same_monet = generator_g(monet, training=False)

    # Discriminator scores
    scores_real_photo = discriminator_x(photo, training=False)
    scores_fake_photo = discriminator_x(fake_photo, training=False)
    scores_real_monet = discriminator_y(monet, training=False)
    scores_fake_monet = discriminator_y(fake_monet, training=False)

    # Pair each history list with the loss value it should receive
    history_updates = (
        (gen_g_losses, generator_loss(scores_fake_monet)),
        (gen_f_losses, generator_loss(scores_fake_photo)),
        (disc_x_losses, discriminator_loss(scores_real_photo, scores_fake_photo)),
        (disc_y_losses, discriminator_loss(scores_real_monet, scores_fake_monet)),
        (cycle_losses, cycle_loss(photo, cycled_photo) + cycle_loss(monet, cycled_monet)),
        (identity_losses, identity_loss(monet, same_monet) + identity_loss(photo, same_photo)),
    )
    for history, value in history_updates:
        history.append(value.numpy())

In [None]:
initial_lr = 2e-4
decay_after_epochs = 20

# PiecewiseConstantDecay is evaluated on the optimizer's step counter, which
# advances once per apply_gradients call (one batch), not once per epoch.
# Count the batches one pass of the training loop consumes so the boundary
# can be expressed in steps.
steps_per_epoch = sum(1 for _ in tf.data.Dataset.zip((photo_ds, monet_ds)))

lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[decay_after_epochs * steps_per_epoch],  # after 20 epochs
    values=[initial_lr, initial_lr * 0.5]  # decay to half
)

# apply clipnorm=1.0 to make backpropagation more stable; every optimizer gets
# the same settings so the two translation directions are trained symmetrically
generator_g_optimizer = tf.keras.optimizers.Adam(lr_schedule, beta_1=0.5, clipnorm=1.0)
generator_f_optimizer = tf.keras.optimizers.Adam(lr_schedule, beta_1=0.5, clipnorm=1.0)
discriminator_x_optimizer = tf.keras.optimizers.Adam(lr_schedule, beta_1=0.5, clipnorm=1.0)
discriminator_y_optimizer = tf.keras.optimizers.Adam(lr_schedule, beta_1=0.5, clipnorm=1.0)

In [None]:
# Per-epoch loss history, filled in by recalc_losses()
gen_g_losses, gen_f_losses = [], []
disc_x_losses, disc_y_losses = [], []
cycle_losses, identity_losses = [], []

# Pair photos (x) with Monet paintings (y); zip ends with the shorter dataset
paired_ds = tf.data.Dataset.zip((photo_ds, monet_ds))

for epoch in range(Config.epochs):
    for image_x, image_y in paired_ds:
        train_step(image_x, image_y)

    # End-of-epoch bookkeeping: record losses, then maybe render a sample
    recalc_losses()
    visualize_epoch_result(epoch)

    print(f"Epoch {epoch+1}/{Config.epochs} done.")

In [None]:
epochs = range(1, len(gen_g_losses) + 1)

# Each recorded history paired with its legend label, in drawing order
loss_curves = [
    (gen_g_losses, 'Generator G Loss (Photo→Monet)'),
    (gen_f_losses, 'Generator F Loss (Monet→Photo)'),
    (disc_x_losses, 'Discriminator X Loss (Photo)'),
    (disc_y_losses, 'Discriminator Y Loss (Monet)'),
    (cycle_losses, 'Cycle Consistency Loss'),
    (identity_losses, 'Identity Loss'),
]

plt.figure(figsize=(12, 6))

for curve_values, curve_label in loss_curves:
    plt.plot(epochs, curve_values, label=curve_label)

plt.title("CycleGAN Training Losses Over Epochs")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
def generate_images(model, test_input):
    """Show the first image of ``test_input`` beside the model's translation of it.

    The model is called in inference mode, and both images are mapped from
    [-1, 1] to [0, 1] before being drawn.
    """
    prediction = model(test_input, training=False)
    plt.figure(figsize=(12, 6))

    panels = zip(['Input Image', 'Translated Image'], [test_input[0], prediction[0]])
    for position, (caption, image) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(caption)
        plt.imshow((image * 0.5 + 0.5).numpy())
        plt.axis('off')
    plt.show()

In [None]:
# Preview the translation of the first five photo batches
for batch in photo_ds.take(5):
    generate_images(generator_g, batch)

In [None]:
! mkdir ../images


In [None]:
# Translate every photo and write the result as a JPEG numbered from 1
for i, img in enumerate(photo_ds, start=1):
    prediction = generator_g(img, training=False)[0].numpy()
    # Map the generator output from [-1, 1] to 8-bit pixel values
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    Image.fromarray(prediction).save("../images/" + str(i) + ".jpg")

In [None]:
# Zip the generated JPEGs into /kaggle/working/images.zip.
# NOTE(review): the images were written via the relative path "../images";
# that is the same directory as "/kaggle/images" only when the notebook's
# working directory is /kaggle/working — confirm before running elsewhere.
shutil.make_archive("/kaggle/working/images", 'zip', "/kaggle/images")
