# In [None]:
import tensorflow as tf
import os
import glob
import random
import matplotlib.pyplot as plt

# Edge length (in pixels) that load_image resizes every image to.
IMG_SIZE = 256
# Number of (photo, Monet) pairs per training batch.
BATCH_SIZE = 1
# Channel count requested from the JPEG decoder and used for the model inputs.
CHANNELS = 3
# MAX_PHOTOS = 3
AUTOTUNE = tf.data.AUTOTUNE
# Seeds Python's `random` module only, so the comparison-photo sample below is
# reproducible. No TensorFlow seed is set in the code visible here.
random.seed(666)

# Data Loading
photo_ds_path = './photo_jpg'
monet_ds_path = './monet_jpg'

# Sorted so the file lists have a deterministic order regardless of how glob
# enumerates the directories.
photo_images = sorted(glob.glob(os.path.join(photo_ds_path, '*.jpg')))
monet_images = sorted(glob.glob(os.path.join(monet_ds_path, '*.jpg')))

# if len(photo_images) > MAX_PHOTOS:
#     photo_images = sorted(random.sample(photo_images, MAX_PHOTOS))
#     print(f'Randomly sampled {len(photo_images)} photos for training')

# Fixed set of up to 10 photos reused for every progress visualization.
comparison_photo_paths = random.sample(photo_images, min(10, len(photo_images)))
# Maps a 1-based epoch number to the list of generated images saved at that epoch.
epoch_snapshots = {}
# One loss value per epoch for each network group; filled in by train().
epoch_loss_history = {"generator": [], "discriminator": []}

# Monet paths as a tensor so map_pair can index them inside the tf.data graph.
monet_files_tensor = tf.constant(monet_images)
# Cast to int64 for use as the modulus applied to the dataset index in map_pair.
monet_count = tf.cast(tf.shape(monet_files_tensor)[0], tf.int64)

def load_image(path):
    """Load a JPEG file and return it as a float32 image scaled to [-1, 1].

    The file at ``path`` is decoded with ``CHANNELS`` channels and resized to
    ``IMG_SIZE`` x ``IMG_SIZE`` before its pixel values are rescaled.
    """
    raw_bytes = tf.io.read_file(path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=CHANNELS)
    resized = tf.image.resize(decoded, [IMG_SIZE, IMG_SIZE])
    # 0..255 -> -1..1, the same range the generators emit through tanh.
    return (tf.cast(resized, tf.float32) / 127.5) - 1.0

def denormalize_image(img):
    """Shift an image from the [-1, 1] range to [0, 1] for display.

    Values that fall outside [0, 1] after the shift are clipped.
    """
    return tf.clip_by_value(
        (img + 1.0) / 2.0,
        clip_value_min=0.0,
        clip_value_max=1.0,
    )

def map_pair(idx, photo_path):
    """Pair a photo with a Monet painting and load both images.

    ``idx`` is the element's position in the enumerated dataset; it is wrapped
    modulo the number of Monet files so the paintings are cycled through when
    there are more photos than paintings.

    Returns a ``(photo_img, monet_img)`` tuple of images from ``load_image``.
    """
    wrapped_idx = tf.math.floormod(idx, monet_count)
    painting_path = tf.gather(monet_files_tensor, wrapped_idx)
    return load_image(photo_path), load_image(painting_path)

def generate_monet_batch(photo_paths):
    """Run ``monet_generator`` on each photo and return display-ready arrays.

    Each photo is processed on its own (batch of one) in inference mode. The
    result is a list, in the order of ``photo_paths``, of NumPy arrays with
    values clipped to [0, 1].
    """
    def _stylize(path):
        # The generator expects a leading batch axis; add it, then drop it again.
        batched_photo = tf.expand_dims(load_image(path), axis=0)
        prediction = monet_generator(batched_photo, training=False)
        return denormalize_image(tf.squeeze(prediction, axis=0)).numpy()

    return [_stylize(path) for path in photo_paths]

def show_photo_to_monet(photo_paths=None, num_examples=3, monet_images=None, title='Original Image to Monet style Image'):
    """Plot each photo next to its Monet-style counterpart, one pair per row.

    Args:
        photo_paths: Iterable of photo file paths. When ``None``, the first
            ``num_examples`` entries of ``comparison_photo_paths`` are used.
        num_examples: Number of comparison photos to show when ``photo_paths``
            is not given. Ignored when ``photo_paths`` is provided.
        monet_images: Sequence of images ready for ``imshow``, aligned with
            ``photo_paths`` (e.g. the output of ``generate_monet_batch``).
            When ``None``, the images are generated from the selected photos.
        title: Figure-level title.
    """
    if photo_paths is None:
        # Previously the None default crashed in list(None); fall back to the
        # fixed comparison set so the function is usable without arguments.
        selected_paths = list(comparison_photo_paths[:num_examples])
    else:
        selected_paths = list(photo_paths)

    if not selected_paths:
        # plt.subplots rejects zero rows, so there is nothing to draw.
        print('show_photo_to_monet: no photos to display')
        return

    if monet_images is None:
        # Previously the None default crashed on indexing; generate on demand.
        monet_images = generate_monet_batch(selected_paths)

    sample_count = len(selected_paths)
    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(
        sample_count, 2, figsize=(5, 2.5 * sample_count), squeeze=False
    )
    for row_idx, path in enumerate(selected_paths):
        photo_ax, monet_ax = axes[row_idx]
        photo_ax.imshow(denormalize_image(load_image(path)).numpy())
        photo_ax.set_title('Original Photo')
        photo_ax.axis('off')
        monet_ax.imshow(monet_images[row_idx])
        monet_ax.set_title('Monet-style Output')
        monet_ax.axis('off')
    fig.suptitle(title, fontsize=14)
    # Leave a thin strip at the top so the suptitle does not overlap the rows.
    plt.tight_layout(rect=(0, 0, 1, 0.985))
    plt.show()

def show_epoch_progression(photo_paths, snapshots):
    """Plot each original photo followed by its generated image at every saved epoch.

    Args:
        photo_paths: Sequence of photo file paths, one figure row per photo.
        snapshots: Mapping of epoch number to a list of images aligned with
            ``photo_paths``. May be empty, in which case only the originals
            are shown.
    """
    photo_paths = list(photo_paths)
    if not photo_paths:
        # plt.subplots rejects zero rows, so there is nothing to draw.
        print('show_epoch_progression: no photos to display')
        return

    epochs = sorted(snapshots.keys())
    sample_count = len(photo_paths)
    column_count = len(epochs) + 1
    # squeeze=False keeps `axes` two-dimensional. Without it, a single row or a
    # single column (no snapshots recorded yet) collapses the array and the
    # row/column indexing below fails.
    fig, axes = plt.subplots(
        sample_count,
        column_count,
        figsize=(4 * column_count, 4 * sample_count),
        squeeze=False,
    )
    for row_idx, path in enumerate(photo_paths):
        row_axes = axes[row_idx]
        row_axes[0].imshow(denormalize_image(load_image(path)).numpy())
        row_axes[0].set_title('Original Photo')
        row_axes[0].axis('off')
        for col_idx, epoch in enumerate(epochs, start=1):
            ax = row_axes[col_idx]
            ax.imshow(snapshots[epoch][row_idx])
            ax.set_title(f'Epoch {epoch}')
            ax.axis('off')
    fig.suptitle('Original vs. Monet-style progression', fontsize=14)
    # Leave a strip at the top so the suptitle does not overlap the first row.
    plt.tight_layout(rect=(0, 0, 1, 0.97))
    plt.show()

def plot_loss_history(loss_history):
    """Draw generator and discriminator loss curves against the epoch number.

    ``loss_history`` must provide the keys ``"generator"`` and
    ``"discriminator"``, each holding one loss value per epoch. The x axis is
    1-based and its length is taken from the generator series.
    """
    curves = (
        ('Generator Loss', loss_history["generator"]),
        ('Discriminator Loss', loss_history["discriminator"]),
    )
    epoch_axis = range(1, len(curves[0][1]) + 1)

    plt.figure(figsize=(8, 4))
    for label, values in curves:
        plt.plot(epoch_axis, values, label=label)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Generator vs Discriminator Loss')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.3)
    plt.tight_layout()
    plt.show()

# Input pipeline: shuffle the photo paths (reshuffled on every pass), number the
# shuffled elements, load each photo together with a Monet image chosen from
# that number, then batch and prefetch.
dataset = (
    tf.data.Dataset.from_tensor_slices(photo_images)
    .shuffle(buffer_size=len(photo_images), reshuffle_each_iteration=True)
    .enumerate()
    .map(map_pair, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(AUTOTUNE)
)

# Model Definitions
def build_generator():
    """Build a convolutional encoder-decoder mapping an image to an image.

    The input has shape ``(IMG_SIZE, IMG_SIZE, CHANNELS)``. Two stride-2
    convolutions reduce the spatial size and two stride-2 transposed
    convolutions restore it. The final tanh layer emits ``CHANNELS`` channels
    in [-1, 1].

    Returns:
        A ``tf.keras.Model`` named ``'generator'``.
    """
    inputs = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, CHANNELS))

    # Stem at full resolution.
    x = tf.keras.layers.Conv2D(64, 7, padding='same')(inputs)
    x = tf.keras.layers.LeakyReLU()(x)

    # Downsampling path.
    x = tf.keras.layers.Conv2D(128, 3, strides=2, padding='same')(x)
    x = tf.keras.layers.LeakyReLU()(x)
    x = tf.keras.layers.Conv2D(256, 3, strides=2, padding='same')(x)
    x = tf.keras.layers.LeakyReLU()(x)

    # Upsampling path.
    x = tf.keras.layers.Conv2DTranspose(128, 3, strides=2, padding='same')(x)
    x = tf.keras.layers.LeakyReLU()(x)
    x = tf.keras.layers.Conv2DTranspose(64, 3, strides=2, padding='same')(x)
    x = tf.keras.layers.LeakyReLU()(x)

    # Emit CHANNELS (not a hard-coded 3) so the output always matches the input
    # shape that the generators and discriminators are built with.
    outputs = tf.keras.layers.Conv2D(CHANNELS, 7, padding='same', activation='tanh')(x)
    return tf.keras.Model(inputs, outputs, name='generator')

def build_discriminator():
    """Build a convolutional discriminator that outputs a map of raw scores.

    Three stride-2 convolutions (64, 128, 256 filters), each followed by a
    LeakyReLU, feed a final single-filter convolution with no activation.

    Returns:
        A ``tf.keras.Model`` named ``'discriminator'``.
    """
    image_in = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, CHANNELS))

    features = image_in
    for filters in (64, 128, 256):
        features = tf.keras.layers.Conv2D(filters, 4, strides=2, padding='same')(features)
        features = tf.keras.layers.LeakyReLU()(features)

    score_map = tf.keras.layers.Conv2D(1, 4, strides=1, padding='same')(features)
    return tf.keras.Model(image_in, score_map, name='discriminator')

# monet_generator is applied to photos and photo_generator to Monet images in
# train_step; each domain gets its own discriminator.
monet_generator, photo_generator = build_generator(), build_generator()
monet_discriminator, photo_discriminator = build_discriminator(), build_discriminator()

# Loss Functions and Optimizers
mse = tf.keras.losses.MeanSquaredError()
# One Adam instance for both generators and one for both discriminators, with
# identical hyperparameters.
optimizer_G, optimizer_D = (
    tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5, beta_2=0.999)
    for _ in range(2)
)

# Flat variable lists so each optimizer updates its pair of networks in one call.
gen_vars = [
    *monet_generator.trainable_variables,
    *photo_generator.trainable_variables,
]
disc_vars = [
    *monet_discriminator.trainable_variables,
    *photo_discriminator.trainable_variables,
]

@tf.function
def train_step(real_photo, real_monet):
    """Run one optimization step for both generators and both discriminators.

    Args:
        real_photo: Batch of photos, as produced by the dataset pipeline.
        real_monet: Batch of Monet images, as produced by the dataset pipeline.

    Returns:
        ``(loss_G, loss_D)``: the combined generator loss and the combined
        discriminator loss for this batch.

    NOTE(review): only adversarial (MSE against ones/zeros) terms are computed
    here; there is no cycle-consistency or identity term. Confirm that this is
    intended.
    """
    # persistent=True because gradient() is called twice on this tape below.
    with tf.GradientTape(persistent=True) as tape:
        fake_monet = monet_generator(real_photo, training=True)
        fake_photo = photo_generator(real_monet, training=True)

        monet_real_logits = monet_discriminator(real_monet, training=True)
        monet_fake_logits = monet_discriminator(fake_monet, training=True)
        photo_real_logits = photo_discriminator(real_photo, training=True)
        photo_fake_logits = photo_discriminator(fake_photo, training=True)

        # Targets shaped like the discriminator output. Both discriminators come
        # from build_discriminator, so the same targets are reused for the
        # photo discriminator.
        valid = tf.ones_like(monet_real_logits)
        fake = tf.zeros_like(monet_fake_logits)

        # Generators are rewarded when their outputs are scored as "valid".
        loss_G = mse(valid, monet_fake_logits) + mse(valid, photo_fake_logits)
        # Discriminators should score real images as "valid" and generated ones as "fake".
        loss_D_monet = mse(valid, monet_real_logits) + mse(fake, monet_fake_logits)
        loss_D_photo = mse(valid, photo_real_logits) + mse(fake, photo_fake_logits)
        loss_D = loss_D_monet + loss_D_photo

    # Each loss is differentiated only with respect to its own group of variables.
    gen_grads = tape.gradient(loss_G, gen_vars)
    disc_grads = tape.gradient(loss_D, disc_vars)

    optimizer_G.apply_gradients(zip(gen_grads, gen_vars))
    optimizer_D.apply_gradients(zip(disc_grads, disc_vars))

    return loss_G, loss_D

# Training Loop
def train(dataset, epochs, visualize_interval):
    """Train the networks and record per-epoch losses and periodic snapshots.

    For every epoch the mean generator and discriminator losses over all
    batches are appended to ``epoch_loss_history``. Every
    ``visualize_interval`` epochs the comparison photos are run through
    ``monet_generator``, stored in ``epoch_snapshots`` under the 1-based epoch
    number, and displayed.

    Args:
        dataset: Batched dataset yielding ``(real_photo, real_monet)`` pairs.
        epochs: Number of passes over the dataset.
        visualize_interval: Snapshot period in epochs. A falsy value (``0`` or
            ``None``) disables snapshots.

    Raises:
        ValueError: If an epoch produces no batches at all.
    """
    steps_per_epoch = len(photo_images) // BATCH_SIZE

    for epoch in range(epochs):
        epoch_number = epoch + 1
        print(f'Epoch {epoch_number}/{epochs}')
        progbar = tf.keras.utils.Progbar(target=steps_per_epoch, verbose=1)

        # Running totals so the recorded value is the epoch mean, not just the
        # loss of whichever batch happened to come last.
        total_G = 0.0
        total_D = 0.0
        steps_done = 0
        for step, (real_photo, real_monet) in enumerate(dataset.take(steps_per_epoch), start=1):
            loss_G, loss_D = train_step(real_photo, real_monet)
            step_G = float(loss_G)
            step_D = float(loss_D)
            total_G += step_G
            total_D += step_D
            steps_done = step
            progbar.update(step, values=[('loss_G', step_G), ('loss_D', step_D)])

        if steps_done == 0:
            # Without this check the summary below would fail on unbound losses.
            raise ValueError('train(): the dataset produced no batches; check the input image folders')

        mean_G = total_G / steps_done
        mean_D = total_D / steps_done
        print(f"Epoch {epoch_number} complete - loss_G: {mean_G:.3f}, loss_D: {mean_D:.3f}")

        epoch_loss_history["generator"].append(mean_G)
        epoch_loss_history["discriminator"].append(mean_D)

        # A falsy interval means "never visualize" instead of dividing by zero.
        if visualize_interval and epoch_number % visualize_interval == 0:
            monet_outputs = generate_monet_batch(comparison_photo_paths)
            epoch_snapshots[epoch_number] = monet_outputs
            show_photo_to_monet(photo_paths=comparison_photo_paths, monet_images=monet_outputs, title=f'Epoch {epoch_number} progress')

# Train model
# 1000 epochs, saving and displaying comparison outputs every 100th epoch.
train(dataset, epochs=1000, visualize_interval=100)

# Plot the per-epoch loss values that train() appended to epoch_loss_history.
plot_loss_history(epoch_loss_history)

# Visualize progression over the recorded epochs
show_epoch_progression(comparison_photo_paths, epoch_snapshots)
