<a href="https://colab.research.google.com/github/vsuhas9/StyleTransfer/blob/dev-suhas/Cycle_GAN_for_Style_Transfer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import Necessary Libraries and Helper Functions

In [None]:
import os
import shutil
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
! pip install tensorflow_addons
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers
from tqdm import tqdm

# Helper Functions for Generator and Discriminator
def downsample(filters, size, apply_instancenorm=True):
    """Build a strided convolution block used on the encoder side.

    Args:
        filters: Number of output channels of the convolution.
        size: Kernel size of the convolution.
        apply_instancenorm: When True, an instance-normalization layer is
            placed between the convolution and the activation.

    Returns:
        A `keras.Sequential` of Conv2D (stride 2, 'same' padding, no bias)
        -> optional InstanceNormalization -> LeakyReLU.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    norm_gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    # Assemble the layers in order, then wrap them in a single Sequential.
    block_layers = [
        layers.Conv2D(filters, size, strides=2, padding='same',
                      kernel_initializer=kernel_init, use_bias=False),
    ]
    if apply_instancenorm:
        block_layers.append(
            tfa.layers.InstanceNormalization(gamma_initializer=norm_gamma_init))
    block_layers.append(layers.LeakyReLU())

    return keras.Sequential(block_layers)

def upsample(filters, size, apply_dropout=False):
    """Build a strided transposed-convolution block used on the decoder side.

    Args:
        filters: Number of output channels of the transposed convolution.
        size: Kernel size of the transposed convolution.
        apply_dropout: When True, a Dropout(0.5) layer is placed between the
            normalization and the activation.

    Returns:
        A `keras.Sequential` of Conv2DTranspose (stride 2, 'same' padding,
        no bias) -> InstanceNormalization -> optional Dropout -> ReLU.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    norm_gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    # Assemble the layers in order, then wrap them in a single Sequential.
    block_layers = [
        layers.Conv2DTranspose(filters, size, strides=2, padding='same',
                               kernel_initializer=kernel_init, use_bias=False),
        tfa.layers.InstanceNormalization(gamma_initializer=norm_gamma_init),
    ]
    if apply_dropout:
        block_layers.append(layers.Dropout(0.5))
    block_layers.append(layers.ReLU())

    return keras.Sequential(block_layers)

Collecting tensorflow_addons
  Downloading tensorflow_addons-0.23.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (611 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m611.8/611.8 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
Collecting typeguard<3.0.0,>=2.7 (from tensorflow_addons)
  Downloading typeguard-2.13.3-py3-none-any.whl (17 kB)
Installing collected packages: typeguard, tensorflow_addons
Successfully installed tensorflow_addons-0.23.0 typeguard-2.13.3



TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



# Defining the Generator and Discriminator Stack

In [None]:
# Generator Model
def Generator(img_shape=[256, 256, 3]):
    """Build the encoder-decoder generator with skip connections.

    The input is passed through eight `downsample` blocks, then through seven
    `upsample` blocks whose outputs are concatenated with the matching encoder
    activations, and finally through a tanh-activated transposed convolution
    that produces a 3-channel image.

    Args:
        img_shape: Shape of one input image, passed to `layers.Input`. The
            shape comments below assume the default 256x256x3.

    Returns:
        A `keras.Model` mapping an image batch to a generated image batch.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    inputs = layers.Input(shape=img_shape)

    # Encoder: each block halves height and width (stride-2 convolution).
    down_stack = [
        downsample(64, 4, apply_instancenorm=False),  # (bs, 128, 128, 64)
        downsample(128, 4),  # (bs, 64, 64, 128)
        downsample(256, 4),  # (bs, 32, 32, 256)
        downsample(512, 4),  # (bs, 16, 16, 512)
        downsample(512, 4),  # (bs, 8, 8, 512)
        downsample(512, 4),  # (bs, 4, 4, 512)
        downsample(512, 4),  # (bs, 2, 2, 512)
        downsample(512, 4),  # (bs, 1, 1, 512)
    ]

    # Decoder: shapes in the comments are AFTER concatenation with the skip.
    up_stack = [
        upsample(512, 4, apply_dropout=True),  # (bs, 2, 2, 1024)
        upsample(512, 4, apply_dropout=True),  # (bs, 4, 4, 1024)
        upsample(512, 4, apply_dropout=True),  # (bs, 8, 8, 1024)
        upsample(512, 4),  # (bs, 16, 16, 1024)
        upsample(256, 4),  # (bs, 32, 32, 512)
        upsample(128, 4),  # (bs, 64, 64, 256)
        upsample(64, 4),  # (bs, 128, 128, 128)
    ]

    # Final layer maps back to 3 channels; tanh keeps outputs in [-1, 1].
    last = layers.Conv2DTranspose(3, 4, strides=2, padding='same',
                                  kernel_initializer=initializer, activation='tanh')  # (bs, 256, 256, 3)

    x = inputs
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)
    # Drop the bottleneck output (it is the decoder input, not a skip) and
    # reverse so the deepest remaining activation pairs with the first upsample.
    skips = reversed(skips[:-1])

    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = layers.Concatenate()([x, skip])
    x = last(x)

    return keras.Model(inputs=inputs, outputs=x)

# Discriminator Model
def Discriminator(img_shape=[256, 256, 3]):
    """Build the convolutional discriminator that scores image patches.

    Three `downsample` blocks are followed by two zero-padded stride-1
    convolutions. The final convolution has a single output channel and no
    activation, so the model returns a grid of raw logits rather than one
    probability per image.

    Args:
        img_shape: Shape of one input image, passed to `layers.Input`. The
            shape comments below assume the default 256x256x3.

    Returns:
        A `tf.keras.Model` mapping an image batch to a grid of logits.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    inp = layers.Input(shape=img_shape, name='input_image')
    x = inp

    x = downsample(64, 4, False)(x)  # (bs, 128, 128, 64)
    x = downsample(128, 4)(x)  # (bs, 64, 64, 128)
    x = downsample(256, 4)(x)  # (bs, 32, 32, 256)

    # ZeroPadding2D defaults to 1 pixel on every side, i.e. +2 per spatial dim.
    x = layers.ZeroPadding2D()(x)  # (bs, 34, 34, 256)
    # 4x4 kernel, stride 1, default 'valid' padding: -3 per spatial dim.
    x = layers.Conv2D(512, 4, strides=1, kernel_initializer=initializer, use_bias=False)(x)  # (bs, 31, 31, 512)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(x)
    x = layers.LeakyReLU()(x)

    x = layers.ZeroPadding2D()(x)  # (bs, 33, 33, 512)
    x = layers.Conv2D(1, 4, strides=1, kernel_initializer=initializer)(x)  # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=inp, outputs=x)


# Step 3: Defining Loss Functions

In [None]:
def discriminator_loss(real, generated):
    """Binary cross-entropy loss for a discriminator.

    Args:
        real: Discriminator logits for real images (target label 1).
        generated: Discriminator logits for generated images (target label 0).

    Returns:
        Half the sum of the two cross-entropy terms. The loss object uses
        `Reduction.NONE`, so the result is not reduced to a single scalar.
    """
    bce = tf.keras.losses.BinaryCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    loss_on_real = bce(tf.ones_like(real), real)
    loss_on_generated = bce(tf.zeros_like(generated), generated)

    return (loss_on_real + loss_on_generated) * 0.5

def generator_loss(generated):
    """Adversarial loss for a generator.

    Args:
        generated: Discriminator logits for the generator's output.

    Returns:
        Binary cross-entropy of the logits against an all-ones target, i.e.
        low when the discriminator rates the generated images as real. The
        loss object uses `Reduction.NONE`, so the result is not reduced to a
        single scalar.
    """
    bce = tf.keras.losses.BinaryCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
    all_real_labels = tf.ones_like(generated)
    return bce(all_real_labels, generated)

def calc_cycle_loss(real_image, cycled_image, LAMBDA):
    """Cycle-consistency loss: weighted mean absolute error.

    Args:
        real_image: Original image batch.
        cycled_image: The same batch after a round trip through both generators.
        LAMBDA: Weight applied to the mean absolute difference.

    Returns:
        `LAMBDA` times the mean of `|real_image - cycled_image|`.
    """
    abs_diff = tf.abs(real_image - cycled_image)
    mean_abs_error = tf.reduce_mean(abs_diff)
    return LAMBDA * mean_abs_error

def identity_loss(real_image, same_image, LAMBDA):
    """Identity loss: half-weighted mean absolute error.

    Args:
        real_image: Image batch already in the generator's output domain.
        same_image: The generator's output for `real_image`.
        LAMBDA: Base weight; the identity term uses half of it.

    Returns:
        `LAMBDA * 0.5` times the mean of `|real_image - same_image|`.
    """
    abs_diff = tf.abs(real_image - same_image)
    mean_abs_error = tf.reduce_mean(abs_diff)
    return LAMBDA * 0.5 * mean_abs_error


# Step - 4: CycleGAN Class

In [None]:
class CycleGAN(keras.Model):
    """Couples two generators and two discriminators into one trainable model.

    `m_gen` translates photos to Monet-style images and `p_gen` translates
    Monet-style images to photos; `m_disc` and `p_disc` judge images of the
    respective domain. Training batches must be `(photo, monet)` pairs, which
    is the `(target, style)` order produced by `create_dataset`.
    """

    # Substring of a checkpoint path that is swapped for each sub-model's name.
    _NAME_PLACEHOLDER = 'model_name'

    def __init__(
        self,
        monet_generator,
        photo_generator,
        monet_discriminator,
        photo_discriminator,
        lambda_cycle=10,
    ):
        """Store the four sub-models and the cycle-loss weight.

        Args:
            monet_generator: Model translating photos to Monet-style images.
            photo_generator: Model translating Monet-style images to photos.
            monet_discriminator: Model scoring Monet-domain images.
            photo_discriminator: Model scoring photo-domain images.
            lambda_cycle: Weight passed to the cycle and identity loss functions.
        """
        super(CycleGAN, self).__init__()
        self.m_gen = monet_generator
        self.p_gen = photo_generator
        self.m_disc = monet_discriminator
        self.p_disc = photo_discriminator
        self.lambda_cycle = lambda_cycle

    def compile(
        self,
        m_gen_optimizer,
        p_gen_optimizer,
        m_disc_optimizer,
        p_disc_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
        identity_loss_fn
    ):
        """Attach one optimizer per sub-model and the four loss functions.

        Args:
            m_gen_optimizer: Optimizer for `m_gen`.
            p_gen_optimizer: Optimizer for `p_gen`.
            m_disc_optimizer: Optimizer for `m_disc`.
            p_disc_optimizer: Optimizer for `p_disc`.
            gen_loss_fn: Called as `fn(disc_logits_for_fake)`.
            disc_loss_fn: Called as `fn(disc_logits_for_real, disc_logits_for_fake)`.
            cycle_loss_fn: Called as `fn(real, cycled, lambda_cycle)`.
            identity_loss_fn: Called as `fn(real, same, lambda_cycle)`.
        """
        super(CycleGAN, self).compile()
        self.m_gen_optimizer = m_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.m_disc_optimizer = m_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn

    def train_step(self, batch_data):
        """Run one optimization step on all four sub-models.

        Args:
            batch_data: A `(real_photo, real_monet)` pair of image batches,
                i.e. the `(target, style)` order yielded by `create_dataset`.

        Returns:
            Dict with the combined loss and the individual generator and
            discriminator losses.
        """
        # The dataset yields (target, style) == (photo, monet). Unpacking in the
        # opposite order would train m_gen on the wrong translation direction.
        real_photo, real_monet = batch_data

        # persistent=True because four separate gradients are taken from one tape.
        with tf.GradientTape(persistent=True) as tape:
            # Photo to monet back to photo
            fake_monet = self.m_gen(real_photo, training=True)
            cycled_photo = self.p_gen(fake_monet, training=True)

            # Monet to photo back to monet
            fake_photo = self.p_gen(real_monet, training=True)
            cycled_monet = self.m_gen(fake_photo, training=True)

            # Generating itself
            same_monet = self.m_gen(real_monet, training=True)
            same_photo = self.p_gen(real_photo, training=True)

            # Discriminator used to check, inputting real images
            disc_real_monet = self.m_disc(real_monet, training=True)
            disc_real_photo = self.p_disc(real_photo, training=True)

            # Discriminator used to check, inputting fake images
            disc_fake_monet = self.m_disc(fake_monet, training=True)
            disc_fake_photo = self.p_disc(fake_photo, training=True)

            # Evaluate generator loss
            monet_gen_loss = self.gen_loss_fn(disc_fake_monet)
            photo_gen_loss = self.gen_loss_fn(disc_fake_photo)

            # Evaluate total cycle consistency loss
            total_cycle_loss = (
                self.cycle_loss_fn(real_monet, cycled_monet, self.lambda_cycle)
                + self.cycle_loss_fn(real_photo, cycled_photo, self.lambda_cycle)
            )

            # Evaluate total generator loss
            total_monet_gen_loss = (
                monet_gen_loss
                + total_cycle_loss
                + self.identity_loss_fn(real_monet, same_monet, self.lambda_cycle)
            )
            total_photo_gen_loss = (
                photo_gen_loss
                + total_cycle_loss
                + self.identity_loss_fn(real_photo, same_photo, self.lambda_cycle)
            )

            # Evaluate discriminator loss
            monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet)
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo)

        # Calculate the gradients for generator and discriminator
        monet_generator_gradients = tape.gradient(total_monet_gen_loss, self.m_gen.trainable_variables)
        photo_generator_gradients = tape.gradient(total_photo_gen_loss, self.p_gen.trainable_variables)
        monet_discriminator_gradients = tape.gradient(monet_disc_loss, self.m_disc.trainable_variables)
        photo_discriminator_gradients = tape.gradient(photo_disc_loss, self.p_disc.trainable_variables)

        # Apply the gradients to the optimizer
        self.m_gen_optimizer.apply_gradients(zip(monet_generator_gradients, self.m_gen.trainable_variables))
        self.p_gen_optimizer.apply_gradients(zip(photo_generator_gradients, self.p_gen.trainable_variables))
        self.m_disc_optimizer.apply_gradients(zip(monet_discriminator_gradients, self.m_disc.trainable_variables))
        self.p_disc_optimizer.apply_gradients(zip(photo_discriminator_gradients, self.p_disc.trainable_variables))

        return {
            "total_loss": total_monet_gen_loss + total_photo_gen_loss + monet_disc_loss + photo_disc_loss,
            "monet_gen_loss": total_monet_gen_loss,
            "photo_gen_loss": total_photo_gen_loss,
            "monet_disc_loss": monet_disc_loss,
            "photo_disc_loss": photo_disc_loss
        }

    def generate(self, image):
        """Translate a single image with `m_gen`.

        Args:
            image: One un-batched image; a leading batch axis of size 1 is added.

        Returns:
            The output of `m_gen` in inference mode, including the batch axis.
        """
        return self.m_gen(tf.expand_dims(image, axis=0), training=False)

    def _weights_path(self, filepath, component):
        """Return `filepath` with the placeholder replaced by `component`.

        Raises:
            ValueError: If `filepath` does not contain the placeholder. Without
                it all four sub-models would resolve to the same file.
        """
        if self._NAME_PLACEHOLDER not in filepath:
            raise ValueError(
                "filepath must contain the placeholder '%s' so each sub-model "
                "gets its own weights file; got %r"
                % (self._NAME_PLACEHOLDER, filepath)
            )
        return filepath.replace(self._NAME_PLACEHOLDER, component)

    def load(
        self,
        filepath
    ):
        """Load the weights of all four sub-models.

        Args:
            filepath: Path template containing the substring 'model_name',
                which is replaced by 'm_gen', 'p_gen', 'm_disc' and 'p_disc'.

        Raises:
            ValueError: If `filepath` does not contain 'model_name'.
        """
        self.m_gen.load_weights(self._weights_path(filepath, 'm_gen'))
        self.p_gen.load_weights(self._weights_path(filepath, 'p_gen'))
        self.m_disc.load_weights(self._weights_path(filepath, 'm_disc'))
        self.p_disc.load_weights(self._weights_path(filepath, 'p_disc'))

    def save(
        self,
        filepath
    ):
        """Save the weights of all four sub-models to separate files.

        Args:
            filepath: Path template containing the substring 'model_name',
                which is replaced by 'm_gen', 'p_gen', 'm_disc' and 'p_disc'.

        Raises:
            ValueError: If `filepath` does not contain 'model_name'; raised
                before anything is written.
        """
        self.m_gen.save_weights(self._weights_path(filepath, 'm_gen'))
        self.p_gen.save_weights(self._weights_path(filepath, 'p_gen'))
        self.m_disc.save_weights(self._weights_path(filepath, 'm_disc'))
        self.p_disc.save_weights(self._weights_path(filepath, 'p_disc'))


# Step - 5: Loading and preprocessing images

In [None]:
def load_and_preprocess_image(path, img_shape):
    """Read a JPEG file and return it resized and scaled to [-1, 1].

    Args:
        path: Path of the JPEG file to read.
        img_shape: Target shape; the first two entries are the resize size and
            the last entry is the number of channels to decode.

    Returns:
        The resized image with pixel values mapped from [0, 255] to [-1, 1].
    """
    target_size = img_shape[:2]
    num_channels = img_shape[-1]

    raw_bytes = tf.io.read_file(path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=num_channels)
    resized = tf.image.resize(decoded, target_size)

    # Normalize the image to [-1, 1]
    return (resized - 127.5) / 127.5

def create_dataset(target_paths, style_paths, img_shape, batch_size):
    """Build a batched dataset of `(target_image, style_image)` pairs.

    Images are decoded lazily inside the `tf.data` pipeline rather than all
    being loaded into memory up front, so an unreadable file surfaces when the
    dataset is iterated rather than when this function is called.

    Args:
        target_paths: Sequence of file paths of the target-domain JPEGs.
        style_paths: Sequence of file paths of the style-domain JPEGs.
        img_shape: Shape every image is decoded and resized to.
        batch_size: Number of pairs per batch.

    Returns:
        A `tf.data.Dataset` yielding `(target_batch, style_batch)` tuples.
        Pairing stops at the shorter of the two path lists.
    """
    def _image_dataset(paths):
        # An explicit string dtype keeps an empty path list a valid string tensor.
        path_ds = tf.data.Dataset.from_tensor_slices(
            tf.constant(list(paths), dtype=tf.string))
        # map() preserves element order by default, so pairing is unchanged.
        return path_ds.map(
            lambda path: load_and_preprocess_image(path, img_shape),
            num_parallel_calls=tf.data.AUTOTUNE,
        )

    target_dataset = _image_dataset(target_paths)
    style_dataset = _image_dataset(style_paths)

    # Zip the two datasets and batch them
    return tf.data.Dataset.zip((target_dataset, style_dataset)).batch(batch_size)


# Step 6: Initialize and Compile CycleGAN

In [None]:
# Image shape shared by every generator/discriminator and by the data pipeline.
img_shape = [256, 256, 3]

# Two generators (one per translation direction) and two discriminators
# (one per image domain), each built as an independent model.
model = CycleGAN(
    monet_generator=Generator(img_shape),
    photo_generator=Generator(img_shape),
    monet_discriminator=Discriminator(img_shape),
    photo_discriminator=Discriminator(img_shape),
    lambda_cycle=10
)

# Every sub-model gets its own Adam optimizer so optimizer state is not shared.
model.compile(
    m_gen_optimizer=tf.keras.optimizers.Adam(1e-4, beta_1=0.5),
    p_gen_optimizer=tf.keras.optimizers.Adam(1e-4, beta_1=0.5),
    m_disc_optimizer=tf.keras.optimizers.Adam(1e-4, beta_1=0.5),
    p_disc_optimizer=tf.keras.optimizers.Adam(1e-4, beta_1=0.5),
    gen_loss_fn=generator_loss,
    disc_loss_fn=discriminator_loss,
    cycle_loss_fn=calc_cycle_loss,
    identity_loss_fn=identity_loss
)


# Step - 7: Loading and Training the model

In [None]:
def list_files(folder_path):
    """Return the paths of all regular files directly inside a folder.

    Args:
        folder_path: Directory to scan, with or without a trailing separator.

    Returns:
        Sorted list of `os.path.join(folder_path, name)` for every entry that
        is a file. Subdirectories are skipped and not searched.

    Raises:
        OSError: If `folder_path` cannot be listed (propagated from `os.listdir`).
    """
    files = []
    for entry in os.listdir(folder_path):
        # Build the path once with os.path.join so the returned path is the one
        # that was checked, even when folder_path has no trailing separator.
        full_path = os.path.join(folder_path, entry)
        if os.path.isfile(full_path):
            files.append(full_path)
    # os.listdir order is arbitrary; sorting makes downstream slicing reproducible.
    files.sort()
    return files

# Example usage
# Photos are the translation targets; Monet paintings provide the style.
target_folder_path = '/content/drive/MyDrive/shared/gan-getting-started/photo_jpg/'
style_folder_path = '/content/drive/MyDrive/shared/gan-getting-started/monet_jpg/'

try:
    target_image_paths = list_files(target_folder_path)
    style_image_paths = list_files(style_folder_path)[:25]
    target_image_paths = target_image_paths[:len(style_image_paths)*2]
except OSError:
    # Stop here with the real cause instead of continuing and failing later
    # with a NameError on the undefined path lists.
    print("Error loading the files")
    raise

print(str(len(target_image_paths)) + " Target Images loaded & " + str(len(style_image_paths)) + " Style Images loaded" )


batch_size = 32

# Create dataset
# NOTE: the two image lists are zipped pairwise, so only as many targets as
# there are style images are actually used for training.
dataset = create_dataset(target_image_paths, style_image_paths, img_shape, batch_size)

# Train the model
epochs = 50
# The dataset is already batched, so batch_size must not be passed to fit().
history = model.fit(dataset, epochs=epochs)


50 Target Images loaded & 25 Style Images loaded
Epoch 1/50


# Step - 8: Display and Store the Images

In [None]:
import matplotlib.pyplot as plt
import os

def generate_display_save(target_dataset, model, num_samples, output_dir):
    """Translate target images, show them in one row and save them as JPEGs.

    Args:
        target_dataset: `tf.data.Dataset` yielding `(target, style)` pairs,
            either batched (rank-4 image tensors) or one image at a time.
        model: Object exposing `generate(image)` for a single un-batched image.
        num_samples: Maximum number of images to translate; fewer are produced
            if the dataset runs out first.
        output_dir: Directory the images are written to as `image_NNNN.jpg`;
            created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)  # Create the output directory

    # model.generate expects ONE image. If the dataset is batched, iterate over
    # individual examples instead of whole batches.
    target_spec = target_dataset.element_spec[0]
    if target_spec.shape.rank == 4:
        target_dataset = target_dataset.unbatch()

    plt.figure(figsize=(10, 10))

    # take() stops cleanly when the dataset holds fewer than num_samples images.
    for i, (img, _) in enumerate(target_dataset.take(num_samples)):  # Only use the target image
        prediction = model.generate(img)
        prediction = tf.squeeze(prediction).numpy()
        prediction = (prediction * 127.5 + 127.5).astype(np.uint8)  # Rescale the pixel values

        # Display the image
        plt.subplot(1, num_samples, i + 1)
        plt.imshow(prediction)
        plt.axis('off')

        # Save the image
        plt.imsave(os.path.join(output_dir, f'image_{i:04d}.jpg'), prediction)

    plt.show()

# Example usage: translate up to five target images from the training dataset
# and write the results into the relative directory 'output'.
generate_display_save(dataset, model, num_samples=5, output_dir='output')  # Adjust num_samples as needed
