In [10]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, LeakyReLU, Activation, Concatenate, Dropout
from tensorflow.keras.layers import Input, UpSampling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

from instancenormalization import InstanceNormalization
from utils import *

# Helper functions required to define the generator and discriminator models

In [3]:
# Down-sampling function
def downsampling(in_layer: tf.Tensor, num_filters: int, kernel_size: int = 4, strides: int = 2) -> tf.Tensor:
    """
    Encoder stage: strided Conv2D -> LeakyReLU(0.2) -> InstanceNormalization.

    Args:
        in_layer (tf.Tensor): Tensor fed into the stage.
        num_filters (int): Number of output channels of the convolution.
        kernel_size (int, optional): Convolution kernel size. Defaults to 4.
        strides (int, optional): Convolution stride; with 'same' padding the default
            of 2 halves each spatial dimension (rounding up). Defaults to 2.

    Returns:
        tf.Tensor: The activated and normalized feature map.
    """
    # Layers are instantiated in the same order they are applied.
    stage = (
        Conv2D(num_filters, kernel_size=kernel_size, strides=strides, padding='same'),
        LeakyReLU(alpha=0.2),
        InstanceNormalization(),
    )
    features = in_layer
    for layer in stage:
        features = layer(features)
    return features

# Up-sampling function
def upsampling(in_layer: tf.Tensor, skip_layer: tf.Tensor, num_filters: int, kernel_size: int = 4, strides: int = 1, dropout_rate: float = 0) -> tf.Tensor:
    """
    Decoder stage: UpSampling2D(2) -> ReLU Conv2D -> optional Dropout ->
    InstanceNormalization, then channel-wise concatenation with a skip tensor.

    Args:
        in_layer (tf.Tensor): Tensor to enlarge.
        skip_layer (tf.Tensor): Encoder tensor appended (after the processed tensor)
            along the channel axis; its spatial size must match the upsampled result.
        num_filters (int): Number of output channels of the convolution.
        kernel_size (int, optional): Convolution kernel size. Defaults to 4.
        strides (int, optional): Convolution stride. Defaults to 1.
        dropout_rate (float, optional): Dropout rate; a falsy value (0) skips the
            Dropout layer entirely. Defaults to 0.

    Returns:
        tf.Tensor: Concatenation of the processed tensor and ``skip_layer``.
    """
    pipeline = [
        UpSampling2D(size=2),
        Conv2D(num_filters, kernel_size=kernel_size, strides=strides, padding='same', activation='relu'),
    ]
    if dropout_rate:
        pipeline.append(Dropout(dropout_rate))
    pipeline.append(InstanceNormalization())

    features = in_layer
    for layer in pipeline:
        features = layer(features)
    return Concatenate()([features, skip_layer])

# Define the U-Net-shaped generator model

In [4]:
def build_generator(img_shape: tuple, in_channels: int = 3, num_filters: int = 32) -> tf.keras.Model:
    """
    Assemble a four-level U-Net generator.

    The encoder applies ``downsampling`` four times with 1x, 2x, 4x and 8x
    ``num_filters``. The decoder applies ``upsampling`` three times (4x, 2x, 1x),
    each step concatenating the encoder output of matching resolution. It ends
    with one UpSampling2D and a tanh-activated 4x4 convolution.

    Args:
        img_shape (tuple): Input shape (height, width, channels).
        in_channels (int, optional): Channel count of the GENERATED image (despite
            the name, it sizes the output layer). Defaults to 3.
        num_filters (int, optional): Filter count of the first encoder stage. Defaults to 32.

    Returns:
        tf.keras.Model: Model mapping an ``img_shape`` image to a tanh-activated image.
    """
    input_layer = Input(shape=img_shape)

    # Encoder: remember every stage output; the last one is the bottleneck.
    encoder_stages = []
    features = input_layer
    for multiplier in (1, 2, 4, 8):
        features = downsampling(in_layer=features, num_filters=multiplier * num_filters)
        encoder_stages.append(features)

    # Decoder: `features` starts at the bottleneck; pair each step with the
    # encoder outputs from deepest skip to shallowest.
    skip_connections = reversed(encoder_stages[:-1])
    for multiplier, skip in zip((4, 2, 1), skip_connections):
        features = upsampling(in_layer=features, skip_layer=skip, num_filters=multiplier * num_filters)

    # The encoder halved four times but the decoder doubled only three times.
    features = UpSampling2D(size=2)(features)

    output_img = Conv2D(in_channels, kernel_size=4, strides=1, padding='same', activation='tanh')(features)
    return Model(input_layer, output_img)


# define the discriminator model

In [5]:
def disc_block(in_layer: tf.Tensor, num_filters: int, kernel_size: int = 4, instance_normalization: bool = True) -> tf.Tensor:
    """
    Discriminator stage: stride-2 Conv2D -> LeakyReLU(0.2) -> optional InstanceNormalization.

    Args:
        in_layer (tf.Tensor): Tensor fed into the stage.
        num_filters (int): Number of output channels of the convolution.
        kernel_size (int, optional): Convolution kernel size. Defaults to 4.
        instance_normalization (bool, optional): Append InstanceNormalization when True.
            Defaults to True.

    Returns:
        tf.Tensor: The stage output (spatial size halved, rounding up, by the stride).
    """
    activated = LeakyReLU(alpha=0.2)(
        Conv2D(num_filters, kernel_size=kernel_size, strides=2, padding='same')(in_layer)
    )
    if not instance_normalization:
        return activated
    return InstanceNormalization()(activated)

def build_discriminator(img_shape: tuple, num_filters: int = 64) -> tf.keras.Model:
    """
    Assemble a patch discriminator from four ``disc_block`` stages.

    The filter counts are 1x, 2x, 4x and 8x ``num_filters``; the first stage skips
    instance normalization. A final stride-1 convolution projects the features to a
    single-channel map, with no activation.

    Args:
        img_shape (tuple): Input shape (height, width, channels).
        num_filters (int, optional): Filter count of the first stage. Defaults to 64.

    Returns:
        tf.keras.Model: Model mapping an image to a one-channel score map whose
        spatial size is the input's reduced by the four stride-2 stages.
    """
    input_layer = Input(shape=img_shape)

    # The first stage is not normalized; the rest double the filter count each time.
    features = disc_block(input_layer, num_filters=num_filters, instance_normalization=False)
    for multiplier in (2, 4, 8):
        features = disc_block(features, num_filters * multiplier)

    disc_output = Conv2D(1, kernel_size=4, strides=1, padding='same')(features)
    return Model(input_layer, disc_output)

# GAN setup

In [None]:
# Filter count of the first layer of each network.
generator_filter = 32
discriminator_filters = 64

# Image geometry (height, width, channels). Presumably each spatial size must be
# divisible by 2**4 so the generator's skip connections line up — TODO confirm.
image_height = 256
image_width = 256
channels = 3
input_shape = (image_height, image_width, channels)

# Loss weights: cycle-consistency weight, and identity weight as a fraction of it.
lambda_cycle = 10.0
lambda_identity = 0.1 * lambda_cycle

# Adam hyper-parameters shared by all models.
optimizer = Adam(learning_rate=0.0002, beta_1=0.5)

# build_discriminator stacks four stride-2 blocks followed by a stride-1 output
# conv, so its score map is 2**4 times smaller than the input along each axis.
# Height and width are handled separately so non-square inputs get the right shape.
disc_downsampling_blocks = 4
patch = image_height // 2**disc_downsampling_blocks
patch_gan_shape = (patch, image_width // 2**disc_downsampling_blocks, 1)

# CycleGAN model

In [None]:
# Every compiled model gets its OWN optimizer instance. Keras optimizers keep
# per-variable slot state (Adam moments). With the optimizer implementation
# introduced in TF 2.11, one instance shared by separately compiled models can
# fail with "optimizer cannot recognize variable", and re-running this cell
# would hit the same problem. The clones copy the hyper-parameters of `optimizer`.
def _clone_optimizer():
    """Return a fresh optimizer configured exactly like the global `optimizer`."""
    return type(optimizer).from_config(optimizer.get_config())


def _build_compiled_discriminator():
    """Build a discriminator and compile it with MSE loss, accuracy metric and its own optimizer."""
    discriminator = build_discriminator(img_shape=input_shape, num_filters=discriminator_filters)
    discriminator.compile(loss='mse',
                          optimizer=_clone_optimizer(),
                          metrics=['accuracy'])
    return discriminator


# discriminator models
disc_A = _build_compiled_discriminator()
disc_B = _build_compiled_discriminator()

# generator models
gen_AtoB = build_generator(img_shape=input_shape, in_channels=channels, num_filters=generator_filter)
gen_BtoA = build_generator(img_shape=input_shape, in_channels=channels, num_filters=generator_filter)

# CycleGAN graph: one real-image input per domain.
real_image_A = Input(shape=input_shape)
real_image_B = Input(shape=input_shape)
# generate fake samples from both generators
fake_image_B = gen_AtoB(real_image_A)
fake_image_A = gen_BtoA(real_image_B)

# *****Reconstruction Loss*****
# Translating a fake image back should recover the original real image.
reconstruct_A = gen_BtoA(fake_image_B)  # should match real images from domain A
reconstruct_B = gen_AtoB(fake_image_A)  # should match real images from domain B

# *****Identity Loss*****
# Feeding a generator an image already in its target domain should leave it unchanged.
identity_A = gen_BtoA(real_image_A)  # should equal the real image from domain A
identity_B = gen_AtoB(real_image_B)  # should equal the real image from domain B

# Freeze the discriminators inside the combined model. They were compiled above
# while trainable; in tf.keras 2.x the flag is captured at compile time, so
# disc_A/disc_B.train_on_batch still update them.
disc_A.trainable = False
disc_B.trainable = False

# *****Adversarial Loss*****
# use discriminator to classify real vs fake
output_A = disc_A(fake_image_A)
output_B = disc_B(fake_image_B)
# The combined model trains the generators to fool the discriminators.
cycle_gan = Model(inputs=[real_image_A, real_image_B],
                  outputs=[output_A, output_B, reconstruct_A, reconstruct_B, identity_A, identity_B])

cycle_gan.compile(loss=['mse', 'mse', 'mae', 'mae', 'mae', 'mae'],  # mse for adversarial losses, mae for reconstruction and identity losses
                  loss_weights=[1, 1, lambda_cycle, lambda_cycle, lambda_identity, lambda_identity],  # weighting used to combine the losses into the final loss
                  optimizer=_clone_optimizer())

# Training CycleGAN model

In [None]:
def trainig(gen_AtoB,
            gen_BtoA,
            disc_A,
            disc_B,
            cyclegan,
            patch_gan_shape,
            epochs,
            path='/dataset/{}'.format(dataset_name),
            batch_size=1,
            sample_interval=50,
            log_interval=50):
    """
    Train a CycleGAN by alternating discriminator and generator updates.

    For every batch ``(imgs_A, imgs_B)`` yielded by ``batch_generator``:
      1. Both generators translate the real images (``fake_B`` from ``imgs_A``,
         ``fake_A`` from ``imgs_B``).
      2. ``disc_A`` is trained on real A (target 1) and fake A (target 0);
         ``disc_B`` likewise on real B and fake B.
      3. ``cyclegan`` is trained with targets
         ``[1s, 1s, imgs_A, imgs_B, imgs_A, imgs_B]``: both adversarial outputs
         should score as real, and reconstructions and identity mappings should
         reproduce the inputs.

    Args:
        gen_AtoB: Generator translating domain A to domain B.
        gen_BtoA: Generator translating domain B to domain A.
        disc_A: Discriminator for domain A, compiled with one metric (``[loss, accuracy]`` is logged).
        disc_B: Discriminator for domain B, compiled like ``disc_A``.
        cyclegan: Compiled combined model taking ``[imgs_A, imgs_B]``. Its six
            outputs must match the target order above. Assumes the discriminators
            are frozen inside it.
        patch_gan_shape (tuple): Discriminator output shape without the batch axis.
        epochs (int): Number of passes over the data.
        path (str): Dataset location passed to ``batch_generator`` and ``plot_sample_images``.
            NOTE(review): the default is evaluated when this cell runs, so
            ``dataset_name`` must already be defined (presumably via ``utils``).
        batch_size (int): Batch size requested from ``batch_generator``.
        sample_interval (int): Save sample images every this many batches.
        log_interval (int): Print losses every this many batches. Defaults to 50.
    """
    patch_gan_shape = tuple(patch_gan_shape)
    # Adversarial loss ground truths
    real_labels = np.ones((batch_size,) + patch_gan_shape)
    fake_labels = np.zeros((batch_size,) + patch_gan_shape)

    # Labels for cyclegan.train_on_batch's return value. tf.keras returns the
    # total loss first, then one loss per output in output order. Output 0 is
    # disc_A(fake_A), where fake_A comes from gen_BtoA; output 1 is disc_B(fake_B).
    gen_loss_labels = ('Generator total loss',
                       'Adversarial loss (B to A)',
                       'Adversarial loss (A to B)',
                       'Reconstruction loss (A)',
                       'Reconstruction loss (B)',
                       'Identity loss (A)',
                       'Identity loss (B)')

    for epoch in range(epochs):
        print(f'Epoch={epoch}')
        batches = batch_generator(path, batch_size, image_res=[image_height, image_width])
        for idx, (imgs_A, imgs_B) in enumerate(batches):
            # A shorter final batch would not match labels sized for `batch_size`
            # (assumes imgs_A supports len(), e.g. a numpy batch).
            current_batch = len(imgs_A)
            if current_batch != len(real_labels):
                real_labels = np.ones((current_batch,) + patch_gan_shape)
                fake_labels = np.zeros((current_batch,) + patch_gan_shape)

            # Translate each domain into the other (inference only, no weight updates).
            fake_B = gen_AtoB.predict_on_batch(imgs_A)
            fake_A = gen_BtoA.predict_on_batch(imgs_B)

            # Train discriminators: real images -> 1, generated images -> 0.
            disc_A_loss_real = disc_A.train_on_batch(imgs_A, real_labels)
            disc_A_loss_fake = disc_A.train_on_batch(fake_A, fake_labels)
            disc_A_loss = 0.5 * np.add(disc_A_loss_real, disc_A_loss_fake)

            disc_B_loss_real = disc_B.train_on_batch(imgs_B, real_labels)
            disc_B_loss_fake = disc_B.train_on_batch(fake_B, fake_labels)
            disc_B_loss = 0.5 * np.add(disc_B_loss_real, disc_B_loss_fake)
            # [loss, accuracy] averaged over both discriminators
            discriminator_loss = 0.5 * np.add(disc_A_loss, disc_B_loss)

            # Train generators through the combined model passed in by the caller.
            gen_loss = cyclegan.train_on_batch([imgs_A, imgs_B],
                                               [
                                                   real_labels, real_labels,
                                                   imgs_A, imgs_B,
                                                   imgs_A, imgs_B
                                               ])

            if idx % log_interval == 0:
                # zip() stops at the shorter sequence, so fewer returned values
                # (e.g. a scalar loss) cannot raise IndexError here.
                gen_report = ' '.join(f'[{label}: {value}]'
                                      for label, value in zip(gen_loss_labels, np.atleast_1d(gen_loss)))
                print(f'[Epoch {epoch}] [Batch {idx}] '
                      f'[Discriminator loss: {discriminator_loss[0]} Accuracy: {100 * discriminator_loss[1]:.2f}] '
                      f'{gen_report}')

            # plot and save progress every few iterations
            if idx % sample_interval == 0:
                plot_sample_images(gen_AtoB,
                                   gen_BtoA,
                                   path=path,
                                   epoch=epoch,
                                   batch_num=idx,
                                   output_dir='images')
                            

            
            
            
            
            