# GAN for cat face generation


<b>Author:</b> Przemyslaw Niedziela (przemyslaw.niedziela98@gmail.com) <br> 
<b>Date:</b> Jan 2025 <br>
<br> <br> 

TL;DR <br>
A GAN (Generative Adversarial Network), built with TensorFlow and Keras, that generates 64x64 images of cat faces (trained on this [dataset](https://www.kaggle.com/datasets/spandan2/cats-faces-64x64-for-generative-models/code)). The architecture consists of a generator that synthesizes images from random noise and a discriminator that classifies images as real or generated. Training is controlled by hyperparameters such as batch size, noise dimension and number of epochs.


In [None]:
import os
import warnings
from typing import Tuple, List, Union
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import imageio

# Silence every warning category for the whole notebook session (note: this also
# hides library deprecation warnings).
warnings.simplefilter('ignore')

In [18]:
# Training hyperparameters.
BATCH_SIZE: int = 256        # images per training step (tf.data batch size)
NOISE_DIM: int = 100         # length of the noise vector fed to the generator
EPOCHS: int = 1_000          # number of passes over the dataset
BUFFER_SIZE: int = 60_000    # tf.data shuffle buffer size

### Initial exploration

In [None]:
def preprocess_image(image: tf.Tensor) -> tf.Tensor:
    """
    Preprocess an image by resizing, keeping at most three channels, and normalizing.

    Args:
        image (tf.Tensor): Input image tensor whose last axis is channels; pixel
            values are assumed to be in [0, 255].

    Returns:
        tf.Tensor: 64x64 image tensor with at most 3 channels, scaled to [-1, 1].
    """
    image = tf.image.resize(image, [64, 64])
    # Slice unconditionally instead of branching on `tf.shape(image)[-1] > 3`:
    # this drops any alpha/extra channels and is a no-op for <= 3 channels. It
    # also avoids using a symbolic tensor as a Python bool when this function is
    # traced by `Dataset.map`.
    image = image[..., :3]
    # Map [0, 255] to [-1, 1], the range of the generator's tanh output.
    image = (image - 127.5) / 127.5
    return image

def load_dataset(path: str) -> tf.data.Dataset:
    """
    Load and preprocess the dataset from the specified directory.

    Images are decoded and resized to 64x64 by Keras, filtered to 64x64x3,
    normalized with `preprocess_image`, shuffled and batched.

    Args:
        path (str): Path to the dataset directory.

    Returns:
        tf.data.Dataset: Shuffled dataset of image batches of up to BATCH_SIZE
        images each (the final batch may be smaller).
    """
    dataset = tf.keras.utils.image_dataset_from_directory(
        path,
        label_mode=None,
        image_size=(64, 64),
        batch_size=None
    )

    def _has_expected_shape(image: tf.Tensor) -> tf.Tensor:
        # Combine the checks with tensor ops: Python `and` on symbolic tensors
        # only works if AutoGraph converts the function, and raises otherwise.
        shape = tf.shape(image)
        return tf.logical_and(
            tf.logical_and(tf.equal(shape[0], 64), tf.equal(shape[1], 64)),
            tf.equal(shape[-1], 3),
        )

    dataset = dataset.filter(_has_expected_shape)
    # Parallel preprocessing + prefetching so the training loop does not wait
    # on the input pipeline.
    dataset = dataset.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset


dataset = load_dataset("cats")

In [None]:
# Preview up to six images from one shuffled batch.
for preview_batch in dataset.take(1):
    plt.figure(figsize=(10, 10))
    preview_count = min(6, preview_batch.shape[0])
    for position in range(preview_count):
        plt.subplot(3, 3, position + 1)
        # Undo the [-1, 1] normalisation to get displayable 8-bit pixels.
        pixels = (preview_batch[position].numpy() * 127.5 + 127.5).astype(np.uint8)
        plt.imshow(pixels)
        plt.axis("off")
    plt.show()

In [None]:
# Pixel statistics over the whole (normalised) dataset, computed by streaming
# running sums per batch so memory stays bounded, rather than holding every
# pixel of every image in memory at once.
pixel_count = 0
pixel_sum = 0.0
pixel_sq_sum = 0.0
pixel_min = float("inf")
pixel_max = float("-inf")

for image_batch in dataset:
    # Accumulate in float64 to limit rounding error over many values.
    batch = image_batch.numpy().astype(np.float64)
    pixel_count += batch.size
    pixel_sum += float(batch.sum())
    pixel_sq_sum += float(np.square(batch).sum())
    pixel_min = min(pixel_min, float(batch.min()))
    pixel_max = max(pixel_max, float(batch.max()))

if pixel_count == 0:
    print("Dataset is empty; no pixel statistics to report.")
else:
    pixel_mean = pixel_sum / pixel_count
    # Population std (as tf.math.reduce_std) via E[x^2] - E[x]^2; clamp tiny
    # negative results caused by floating-point rounding.
    pixel_std = max(pixel_sq_sum / pixel_count - pixel_mean ** 2, 0.0) ** 0.5
    print(f"Mean pixel value: {pixel_mean:.4f}")
    print(f"Std dev pixel value: {pixel_std:.4f}")
    print(f"Pixel value range: [{pixel_min}, {pixel_max}]")

In [None]:
# Per-channel pixel-intensity histograms for the first 100 images of the
# (shuffled) dataset.
sampled_images = [sample.numpy() for sample in dataset.unbatch().take(100)]
channel_data = {
    channel_idx: [
        value
        for sample in sampled_images
        for value in sample[:, :, channel_idx].flatten()
    ]
    for channel_idx in range(3)
}

plt.figure(figsize=(15, 5))
for channel_idx, values in channel_data.items():
    plt.subplot(1, 3, channel_idx + 1)
    plt.hist(values, bins=50, alpha=0.7, label=f'Channel {channel_idx+1}')
    plt.title(f'Pixel Intensity Distribution - Channel {channel_idx+1}')
    plt.xlabel('Pixel Intensity')
    plt.ylabel('Frequency')
    plt.legend()
plt.tight_layout()
plt.show()

### Model definition

In [7]:
class Generator(tf.keras.Model):
    """
    Generator model for the GAN.
    Generates 64x64x3 images with values in [-1, 1] from random noise vectors.

    The noise is projected to a 4x4x512 feature map and upsampled by four
    stride-2 transposed convolutions (4 -> 8 -> 16 -> 32 -> 64).
    """
    def __init__(self, noise_dim: int = NOISE_DIM):
        """
        Args:
            noise_dim (int): Length of the input noise vector. Defaults to the
                notebook-wide NOISE_DIM, so the model matches the noise sampled
                during training instead of a separately hard-coded size.
        """
        super().__init__()
        self.model = tf.keras.Sequential([
            layers.Dense(4*4*512, use_bias=False, input_shape=(noise_dim,)),
            layers.BatchNormalization(),
            layers.ReLU(),

            layers.Reshape((4, 4, 512)),

            layers.Conv2DTranspose(256, (5, 5), strides=(2, 2), padding='same', use_bias=False),
            layers.BatchNormalization(),
            layers.ReLU(),

            layers.Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same', use_bias=False),
            layers.BatchNormalization(),
            layers.ReLU(),

            layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False),
            layers.BatchNormalization(),
            layers.ReLU(),

            # tanh keeps outputs in [-1, 1], matching the normalised training images.
            layers.Conv2DTranspose(3, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')
        ])

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        """
        Forward pass for the generator.

        Args:
            inputs (tf.Tensor): Input noise vector of shape (batch_size, noise_dim).

        Returns:
            tf.Tensor: Generated images of shape (batch_size, 64, 64, 3).
        """
        return self.model(inputs)


In [8]:
class Discriminator(tf.keras.Model):
    """
    Discriminator model for the GAN.
    Scores images as real or fake; the output is an unnormalised logit
    (positive means "real", negative means "fake").
    """
    def __init__(self):
        super().__init__()
        self.model = tf.keras.Sequential([
            layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[64, 64, 3]),
            layers.LeakyReLU(alpha=0.2),
            layers.Dropout(0.3),

            layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
            layers.LeakyReLU(alpha=0.2),
            layers.Dropout(0.3),

            layers.Conv2D(256, (5, 5), strides=(2, 2), padding='same'),
            layers.LeakyReLU(alpha=0.2),
            layers.Dropout(0.3),

            layers.Conv2D(512, (5, 5), strides=(2, 2), padding='same'),
            layers.LeakyReLU(alpha=0.2),
            layers.Dropout(0.3),

            layers.Flatten(),
            # No activation: emit raw logits. The training loss is
            # BinaryCrossentropy(from_logits=True), which applies the sigmoid
            # itself; a sigmoid here would be applied twice and flatten the
            # gradients reaching both networks.
            layers.Dense(1)
        ])

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        """
        Forward pass for the discriminator.

        Args:
            inputs (tf.Tensor): Input images of shape (batch_size, 64, 64, 3).

        Returns:
            tf.Tensor: Classification logits of shape (batch_size, 1).
        """
        return self.model(inputs)

In [9]:
# Build both networks.
generator = Generator()
discriminator = Discriminator()

# One Adam optimizer per network, both with learning rate 1e-4.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

# Shared binary cross-entropy; from_logits=True treats predictions as raw
# scores and applies the sigmoid internally.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [10]:
def discriminator_loss(real_output: tf.Tensor, fake_output: tf.Tensor) -> tf.Tensor:
    """
    Discriminator objective: label real images 1 and generated images 0.

    Args:
        real_output (tf.Tensor): Discriminator predictions on real images.
        fake_output (tf.Tensor): Discriminator predictions on fake images.

    Returns:
        tf.Tensor: Sum of the cross-entropy on real and on fake predictions.
    """
    real_targets = tf.ones_like(real_output)
    fake_targets = tf.zeros_like(fake_output)
    loss_on_real = cross_entropy(real_targets, real_output)
    loss_on_fake = cross_entropy(fake_targets, fake_output)
    return loss_on_real + loss_on_fake

def generator_loss(fake_output: tf.Tensor) -> tf.Tensor:
    """
    Generator objective: make the discriminator label generated images as real (1).

    Args:
        fake_output (tf.Tensor): Discriminator predictions on generated images.

    Returns:
        tf.Tensor: Cross-entropy of the fake predictions against all-ones targets.
    """
    wanted_targets = tf.ones_like(fake_output)
    return cross_entropy(wanted_targets, fake_output)


In [11]:
def generate_and_save_images(
    model: tf.keras.Model,
    epoch: int,
    test_input: tf.Tensor,
    save: bool = False,
    output_dir: str = 'generated_cats',
) -> None:
    """
    Generate images with the generator, display them in a grid, and optionally save them.

    Args:
        model (tf.keras.Model): The generator model.
        epoch (int): Current training epoch (used in the saved file name).
        test_input (tf.Tensor): Noise batch fed to the generator; one image per row.
        save (bool): If True, saves the figure as
            ``<output_dir>/image_at_epoch_<epoch:04d>.png``.
        output_dir (str): Directory for saved figures; created if missing.
    """
    predictions = model(test_input, training=False)
    num_images = predictions.shape[0]
    # Square grid sized from the batch: 16 inputs give the usual 4x4 layout,
    # and larger batches no longer overflow a fixed 4x4 grid.
    grid_size = max(1, int(np.ceil(np.sqrt(num_images))))

    fig = plt.figure(figsize=(4, 4))

    for i in range(num_images):
        plt.subplot(grid_size, grid_size, i + 1)
        # Map generator output from [-1, 1] to [0, 1] for imshow.
        plt.imshow((predictions[i] + 1) / 2.0)
        plt.axis('off')

    if save:
        # savefig does not create directories; without this the first save
        # raises FileNotFoundError.
        os.makedirs(output_dir, exist_ok=True)
        plt.savefig(os.path.join(output_dir, f'image_at_epoch_{epoch:04d}.png'))
    plt.show()
    # Release the figure; this is called once per epoch and open figures
    # would otherwise accumulate.
    plt.close(fig)

In [12]:
@tf.function
def train_step(images: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """
    Perform a single training step for the generator and discriminator.

    Args:
        images (tf.Tensor): Batch of real images.

    Returns:
        Tuple[tf.Tensor, tf.Tensor]: Generator and discriminator losses.
    """
    # Size the noise from the real batch, not BATCH_SIZE, so a smaller final
    # batch is compared against an equal number of generated images.
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, NOISE_DIM])

    # Two tapes: each network's gradients are taken w.r.t. its own loss.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss

def train(dataset: tf.data.Dataset, epochs: int) -> Tuple[List[float], List[float]]:
    """
    Train the GAN for a specified number of epochs.

    After each epoch a grid generated from a fixed noise batch is displayed and
    saved, so progress is comparable across epochs.

    Args:
        dataset (tf.data.Dataset): Preprocessed, batched dataset for training.
        epochs (int): Number of training epochs.

    Returns:
        Tuple[List[float], List[float]]: Per-epoch mean generator and
        discriminator losses.

    Raises:
        ValueError: If the dataset yields no batches in an epoch.
    """
    losses_gen: List[float] = []
    losses_disc: List[float] = []
    # Fixed noise reused every epoch so saved snapshots show the same latent points evolving.
    noise = tf.random.normal([16, NOISE_DIM])
    for epoch in range(epochs):
        epoch_gen_loss: List[float] = []
        epoch_disc_loss: List[float] = []

        for image_batch in dataset:
            gen_loss, disc_loss = train_step(image_batch)
            epoch_gen_loss.append(float(gen_loss))
            epoch_disc_loss.append(float(disc_loss))

        if not epoch_gen_loss:
            raise ValueError("Dataset yielded no batches; cannot train on an empty dataset.")

        avg_gen_loss = sum(epoch_gen_loss) / len(epoch_gen_loss)
        avg_disc_loss = sum(epoch_disc_loss) / len(epoch_disc_loss)

        losses_gen.append(avg_gen_loss)
        losses_disc.append(avg_disc_loss)

        generate_and_save_images(generator, epoch + 1, noise, save=True)

        print(f"Epoch {epoch + 1}, Gen Loss: {avg_gen_loss:.4f}, Disc Loss: {avg_disc_loss:.4f}")

    return losses_gen, losses_disc

In [None]:
# Run the training loop and keep the per-epoch mean losses for plotting.
losses_gen, losses_disc = train(dataset=dataset, epochs=EPOCHS)

In [None]:
def visualize_generated_images(generator: tf.keras.Model, num_images: int = 16) -> None:
    """
    Generate and visualize images from the GAN generator.

    Args:
        generator (tf.keras.Model): The generator model.
        num_images (int): Number of images to generate and display.
    """
    noise = tf.random.normal([num_images, NOISE_DIM])
    generated_images = generator(noise, training=False)
    # Map generator output from [-1, 1] to [0, 1] for imshow.
    generated_images = (generated_images + 1) / 2.0

    # Size the grid from num_images (16 -> 4x4) instead of a fixed 4x4 grid,
    # which raised IndexError for num_images > 16.
    grid_size = max(1, int(np.ceil(np.sqrt(num_images))))
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(8, 8), squeeze=False)
    axes = axes.flatten()

    for i, ax in enumerate(axes):
        if i < num_images:
            ax.imshow(generated_images[i])
        # Also hides the spare cells when num_images is not a perfect square.
        ax.axis('off')
    plt.tight_layout()
    plt.show()

visualize_generated_images(generator, num_images=16)

In [None]:
def plot_losses(generator_losses: List[float], discriminator_losses: List[float]) -> None:
    """
    Draw the generator and discriminator loss curves on a shared set of axes.

    Args:
        generator_losses (List[float]): Generator loss per epoch.
        discriminator_losses (List[float]): Discriminator loss per epoch.

    Returns:
        None
    """
    _, ax = plt.subplots(figsize=(10, 5))
    curves = (
        (generator_losses, "Generator Loss"),
        (discriminator_losses, "Discriminator Loss"),
    )
    for series, curve_label in curves:
        ax.plot(series, label=curve_label)
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")
    ax.set_title("Losses during GAN Training")
    ax.legend()
    plt.show()

plot_losses(losses_gen, losses_disc)

In [None]:
def visualize_latent_interpolation(generator: tf.keras.Model, steps: int = 10) -> None:
    """
    Visualize latent space interpolation by generating smooth transitions between images.

    Two random noise vectors are drawn and linearly interpolated. Each
    intermediate vector is decoded by the generator and shown in a single row.

    Args:
        generator (tf.keras.Model): The generator model.
        steps (int): Number of images in the row, including both endpoints.
            With steps == 1 only the starting point is shown.

    Raises:
        ValueError: If steps is less than 1.
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")

    noise_start = tf.random.normal([1, NOISE_DIM])
    noise_end = tf.random.normal([1, NOISE_DIM])
    # Weights 0, 1/(steps-1), ..., 1 as a (steps, 1) column that broadcasts over
    # the noise dimension. tf.linspace gives [0.] for steps == 1, avoiding the
    # division by zero of t / (steps - 1).
    alphas = tf.reshape(tf.linspace(0.0, 1.0, steps), (-1, 1))
    interpolated_noise = noise_start + alphas * (noise_end - noise_start)

    generated_images = generator(interpolated_noise, training=False)
    # Map generator output from [-1, 1] to [0, 1] for imshow.
    generated_images = (generated_images + 1) / 2.0

    # squeeze=False keeps a 2-D axes array even when steps == 1.
    fig, axes = plt.subplots(1, steps, figsize=(20, 5), squeeze=False)
    for ax, img in zip(axes[0], generated_images):
        ax.imshow(img)
        ax.axis('off')
    plt.show()


visualize_latent_interpolation(generator)

In [None]:
def visualize_image_diversity(generator: tf.keras.Model, num_images: int = 16) -> None:
    """
    Visualize the diversity of generated images by sampling multiple random noise vectors.

    Args:
        generator (tf.keras.Model): The generator model.
        num_images (int): Number of images to generate and display.
    """
    noise = tf.random.normal([num_images, NOISE_DIM])
    generated_images = generator(noise, training=False)
    # Map generator output from [-1, 1] to [0, 1] for imshow.
    generated_images = (generated_images + 1) / 2.0

    # Grid sized from num_images (16 -> 4x4); a fixed 4x4 grid raised
    # IndexError for num_images > 16.
    grid_size = max(1, int(np.ceil(np.sqrt(num_images))))
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(8, 8), squeeze=False)
    axes = axes.flatten()
    for i, ax in enumerate(axes):
        if i < num_images:
            ax.imshow(generated_images[i])
        # Also hides the spare cells when num_images is not a perfect square.
        ax.axis('off')
    plt.tight_layout()
    plt.show()

visualize_image_diversity(generator)

In [None]:
def create_gif_from_images(image_folder: str, gif_name: str, duration: float = 0.3) -> None:
    """
    Generate a GIF from the PNG images saved in a folder.

    Frames are ordered by file name, so zero-padded names (e.g.
    ``image_at_epoch_0001.png``) play back in numeric order. The GIF is written
    inside ``image_folder``. If the folder is missing or contains no PNGs, a
    message is printed and nothing is written.

    Args:
        image_folder (str): Path to the folder containing the images.
        gif_name (str): Output file name for the GIF (placed in image_folder).
        duration (float): Per-frame duration, passed unchanged to ``imageio.mimsave``.
            NOTE(review): older imageio releases read this as seconds, newer
            Pillow-backed GIF writers as milliseconds; confirm against the
            installed version.
    """
    # Report a missing folder the same way as an empty one instead of letting
    # os.listdir raise FileNotFoundError.
    if not os.path.isdir(image_folder):
        print(f"Image folder not found: {image_folder}")
        return

    images = sorted(
        os.path.join(image_folder, name)
        for name in os.listdir(image_folder)
        if name.endswith('.png')
    )

    if not images:
        print(f"No images found in folder: {image_folder}")
        return

    frames = [imageio.imread(image) for image in images]
    gif_path = os.path.join(image_folder, gif_name)
    imageio.mimsave(gif_path, frames, duration=duration)

# Assemble the per-epoch snapshots into an animated GIF.
create_gif_from_images(image_folder='generated_cats', gif_name='cats_gan.gif')