In [1]:
import tensorflow as tf

2024-11-14 09:55:15.857744: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-11-14 09:55:16.000853: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1731599716.052055    3300 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1731599716.065931    3300 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-14 09:55:16.184709: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

## CycleGAN

### Overview

This implementation provides a complete TensorFlow-based CycleGAN framework for image-to-image translation. The codebase includes data processing, model architecture, training pipeline, and evaluation metrics.


## Configuration


In [2]:
gpu_devices = tf.config.experimental.list_physical_devices("GPU")
if gpu_devices:
    for device in gpu_devices:
        tf.config.experimental.set_memory_growth(device, True)

    strategy = tf.distribute.MirroredStrategy()
    print(f"Number of devices: {strategy.num_replicas_in_sync}")

else:
    print("No GPU devices found.")

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
Number of devices: 1


I0000 00:00:1731599719.789161    3300 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 3586 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3060 Laptop GPU, pci bus id: 0000:01:00.0, compute capability: 8.6


## Model Parameters


In [3]:
from dataclasses import dataclass

In [4]:
@dataclass
class ModelConfig:
    height: int = 256
    width: int = 256
    channels: int = 3
    base_filters: int = 64
    lambda_cycle: float = 10.0
    lambda_identity: float = 0.5
    learning_rate: float = 2e-4
    beta_1: float = 0.5

    def get_config(self):
        return {
            "height": self.height,
            "width": self.width,
            "channels": self.channels,
            "base_filters": self.base_filters,
            "lambda_cycle": self.lambda_cycle,
            "lambda_identity": self.lambda_identity,
            "learning_rate": self.learning_rate,
            "beta_1": self.beta_1,
        }

    @classmethod
    def from_config(cls, config):
        return cls(**config)

## Loading the data


In [5]:
class ImageProcessor:
    def __init__(self, config: ModelConfig):
        self.config = config

    def decode_image(self, image: tf.Tensor) -> tf.Tensor:
        """
        Decodes and normalizes the image.

        Args:
          image: A tensor representing an image.

        Returns:
          The decoded and normalized image.
        """

        image = tf.image.decode_jpeg(image, channels=self.config.channels)
        image = tf.cast(image, tf.float32)
        image = (image / 127.5) - 1

        return image

    @tf.function
    def parse_tfrecord(self, example_photo):
        """
        Parses a TFRecord example containing a photo.

        Args:
          example_photo: A TFRecord example containing a photo.

        Returns:
          The decoded and normalized photo.
        """

        feature_description = {
            "image_name": tf.io.FixedLenFeature([], tf.string),
            "image": tf.io.FixedLenFeature([], tf.string),
            "target": tf.io.FixedLenFeature([], tf.string),
        }

        example = tf.io.parse_single_example(example_photo, feature_description)
        return self.decode_image(example["image"])

    def create_dataset(
        self,
        filenames: list,
        batch_size: int = 1,
        shuffle: bool = True,
        cache: bool = True,
    ) -> tf.data.Dataset:
        """
        Creates a dataset from TFRecord files.

        Args:
            filenames (list): A list of TFRecord filenames.
            batch_size (int): The batch size.
            shuffle (bool): Whether to shuffle the dataset.
            cache (bool): Whether to cache the dataset in memory.

        Returns:
            A tf.data.Dataset instance.
        """

        dataset = tf.data.TFRecordDataset(
            filenames, num_parallel_reads=tf.data.experimental.AUTOTUNE
        )

        dataset = dataset.map(
            self.parse_tfrecord, num_parallel_calls=tf.data.experimental.AUTOTUNE
        )

        if shuffle:
            dataset = dataset.shuffle(buffer_size=2048, reshuffle_each_iteration=True)

        dataset = dataset.batch(batch_size, drop_remainder=True)
        if cache:
            dataset = dataset.cache()

        return dataset.prefetch(tf.data.experimental.AUTOTUNE)

## Visualization Functions


In [6]:
import matplotlib.pyplot as plt
import numpy as np
from typing import Optional, Tuple

In [7]:
class VisualizationUtils:
    @staticmethod
    def denormalize_image(image: tf.Tensor) -> tf.Tensor:
        """
        Denormalizes the image by scaling it back to the [0, 255] range.

        Args:
            image: A tensor representing an image.

        Returns:
            The denormalized image.
        """

        return (image * 0.5) + 0.5

    @staticmethod
    def create_figure(
        nrows: int, ncols: int, figsize: Optional[Tuple[int, int]] = None
    ) -> Tuple[plt.Figure, plt.Axes]:
        """
        Create a Matplotlib figure and axes with the given number of rows and columns.

        Args:
            nrows (int): The number of rows.
            ncols (int): The number of columns.
            figsize (tuple): The figure size.

        Returns:
            A tuple containing the figure and axes.
        """

        if figsize is None:
            figsize = (ncols * 4, nrows * 4)

        fig, axes = plt.subplots(nrows, ncols, figsize=figsize)

        if nrows * ncols == 1:
            axes = np.array([axes])

        return fig, axes.flatten()

## Model Building


In [8]:
import keras

In [9]:
@keras.saving.register_keras_serializable(package="CycleGAN", name="DiscriminatorBlock")
class DownsampleBlock(tf.keras.layers.Layer):
    def __init__(
        self,
        filters: int,
        size: int = 4,
        strides: int = 2,
        apply_norm: bool = True,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.conv = tf.keras.layers.Conv2D(
            filters=filters,
            kernel_size=size,
            strides=strides,
            padding="same",
            kernel_initializer=tf.keras.initializers.RandomNormal(
                mean=0.0, stddev=0.02
            ),
            use_bias=not apply_norm,
        )

        self.batch_norm = tf.keras.layers.BatchNormalization(
            gamma_initializer=(
                tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
                if apply_norm
                else None
            )
        )

        self.activation = tf.keras.layers.LeakyReLU(negative_slope=0.2)

    def call(self, x: tf.Tensor, training: bool = True) -> tf.Tensor:
        x = self.conv(x)
        if self.batch_norm is not None:
            x = self.batch_norm(x, training=training)

        return self.activation(x)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "size": self.size,
                "strides": self.strides,
                "apply_norm": self.apply_norm,
            }
        )

        return config

In [10]:
@keras.saving.register_keras_serializable(package="CycleGAN", name="UpsampleBlock")
class UpsampleBlock(tf.keras.layers.Layer):
    def __init__(
        self,
        filters: int,
        size: int = 4,
        strides: int = 2,
        apply_dropout: bool = False,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.conv_transpose = tf.keras.layers.Conv2DTranspose(
            filters=filters,
            kernel_size=size,
            strides=strides,
            padding="same",
            kernel_initializer=tf.keras.initializers.RandomNormal(
                mean=0.0, stddev=0.02
            ),
            use_bias=False,
        )

        self.batch_norm = tf.keras.layers.BatchNormalization(
            gamma_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
        )

        self.dropout = tf.keras.layers.Dropout(0.5) if apply_dropout else None
        self.activation = tf.keras.layers.ReLU()

    def call(self, x: tf.Tensor, training: bool = True) -> tf.Tensor:
        x = self.conv_transpose(x)
        x = self.batch_norm(x, training=training)

        if self.dropout is not None:
            x = self.dropout(x, training=training)

        return self.activation(x)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "size": self.size,
                "strides": self.strides,
                "apply_dropout": self.apply_dropout,
            }
        )

        return config

## Generator Model


In [11]:
@keras.saving.register_keras_serializable(package="CycleGAN", name="Generator")
class Generator(tf.keras.Model):
    def __init__(self, config: ModelConfig, name: str = "generator", **kwargs):
        super().__init__(name=name, **kwargs)
        if isinstance(config, dict):
            config = ModelConfig(**config)

        self.config = config
        self.model_config = config

        self.downsample_stack = [
            DownsampleBlock(64, 4, apply_norm=False),
            DownsampleBlock(128, 4),
            DownsampleBlock(256, 4),
            DownsampleBlock(512, 4),
            DownsampleBlock(512, 4),
            DownsampleBlock(512, 4),
            DownsampleBlock(512, 4),
            DownsampleBlock(512, 4),
        ]

        self.upsample_stack = [
            UpsampleBlock(512, 4, apply_dropout=True),
            UpsampleBlock(512, 4, apply_dropout=True),
            UpsampleBlock(512, 4, apply_dropout=True),
            UpsampleBlock(512, 4),
            UpsampleBlock(256, 4),
            UpsampleBlock(128, 4),
            UpsampleBlock(64, 4),
        ]

        self.final_conv = tf.keras.layers.Conv2DTranspose(
            filters=config.channels,
            kernel_size=4,
            strides=2,
            padding="same",
            kernel_initializer=tf.keras.initializers.RandomNormal(
                mean=0.0, stddev=0.02
            ),
            activation="tanh",
        )

    def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor:
        skips = []
        for down in self.downsample_stack:
            x = down(x, training=training)
            skips.append(x)

        skips = reversed(skips[:-1])

        for up, skip in zip(self.upsample_stack, skips):
            x = up(x, training=training)
            x = tf.keras.layers.Concatenate()([x, skip])

        return self.final_conv(x)

    def get_config(self):
        return {"config": self.config.get_config(), "name": self.name}

    @classmethod
    def from_config(cls, config):
        return cls(
            config=ModelConfig.from_config(config["config"]), name=config["name"]
        )

In [12]:
@keras.saving.register_keras_serializable(package="CycleGAN", name="Discriminator")
class Discriminator(tf.keras.Model):
    """Modern implementation of PatchGAN discriminator."""

    def __init__(self, config: ModelConfig, name: str = "discriminator", **kwargs):
        super().__init__(name=name, **kwargs)
        self.config = config

        self.down_stack = [
            DownsampleBlock(config.base_filters, apply_norm=False),
            DownsampleBlock(config.base_filters * 2),
            DownsampleBlock(config.base_filters * 4),
        ]

        self.zero_pad1 = tf.keras.layers.ZeroPadding2D()
        self.conv = tf.keras.layers.Conv2D(
            config.base_filters * 8,
            4,
            strides=1,
            kernel_initializer=tf.keras.initializers.RandomNormal(0.0, 0.02),
            use_bias=False,
        )
        self.batch_norm = tf.keras.layers.BatchNormalization(
            gamma_initializer=tf.keras.initializers.RandomNormal(0.0, 0.02)
        )
        self.leaky_relu = tf.keras.layers.LeakyReLU(0.2)
        self.zero_pad2 = tf.keras.layers.ZeroPadding2D()
        self.final_conv = tf.keras.layers.Conv2D(
            1,
            4,
            strides=1,
            kernel_initializer=tf.keras.initializers.RandomNormal(0.0, 0.02),
        )

    def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor:
        for down in self.down_stack:
            x = down(x, training=training)

        x = self.zero_pad1(x)
        x = self.conv(x)
        x = self.batch_norm(x, training=training)
        x = self.leaky_relu(x)
        x = self.zero_pad2(x)

        return self.final_conv(x)

    def get_config(self):
        return {"config": self.config.get_config(), "name": self.name}

    @classmethod
    def from_config(cls, config):
        return cls(
            config=ModelConfig.from_config(config["config"]), name=config["name"]
        )

## Gan Model


In [13]:
from typing import Tuple, Dict, Optional
from pathlib import Path

In [14]:
@keras.saving.register_keras_serializable(package="CycleGAN", name="CycleGAN")
class CycleGAN(tf.keras.Model):
    def __init__(self, config: ModelConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config

        # Generators
        self.gen_G = Generator(config, name="generator_G")
        self.gen_F = Generator(config, name="generator_F")

        # Discriminators
        self.disc_X = Discriminator(config, name="discriminator_X")
        self.disc_Y = Discriminator(config, name="discriminator_Y")

        # Optimizers with same settings
        optimizer_kwargs = dict(
            learning_rate=config.learning_rate, beta_1=config.beta_1
        )
        self.gen_G_optimizer = tf.keras.optimizers.Adam(**optimizer_kwargs)
        self.gen_F_optimizer = tf.keras.optimizers.Adam(**optimizer_kwargs)
        self.disc_X_optimizer = tf.keras.optimizers.Adam(**optimizer_kwargs)
        self.disc_Y_optimizer = tf.keras.optimizers.Adam(**optimizer_kwargs)

        # Loss trackers
        self.gen_G_loss_tracker = tf.keras.metrics.Mean(name="gen_G_loss")
        self.gen_F_loss_tracker = tf.keras.metrics.Mean(name="gen_F_loss")
        self.disc_X_loss_tracker = tf.keras.metrics.Mean(name="disc_X_loss")
        self.disc_Y_loss_tracker = tf.keras.metrics.Mean(name="disc_Y_loss")
        self.cycle_loss_tracker = tf.keras.metrics.Mean(name="cycle_loss")
        self.identity_loss_tracker = tf.keras.metrics.Mean(name="identity_loss")

    def compile(self, **kwargs):
        super().compile(**kwargs)

    def call(self, inputs, training=False):
        if isinstance(inputs, tuple):
            real_x, real_y = inputs

            if training:
                fake_y = self.gen_G(real_x, training=training)
                fake_x = self.gen_F(real_y, training=training)

                return fake_y, fake_x

            else:
                return self.gen_G(real_x, training=training)

        return self.gen_G(inputs, training=training)

    def _generator_loss(self, disc_generated_output):
        """
        Calculates the generator loss using the discriminator output.

        Args:
            disc_generated_output: Discriminator output.

        Returns:
            The generator loss.
        """

        return tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(
                tf.ones_like(disc_generated_output),
                disc_generated_output,
                from_logits=True,
            )
        )

    def _discriminator_loss(self, real_output, fake_output):
        """
        Calculates the discriminator loss.

        Args:
            real_output: Real output.
            fake_output: Fake output.

        Returns:
            The discriminator loss.
        """

        real_loss = tf.keras.losses.binary_crossentropy(
            tf.ones_like(real_output), real_output, from_logits=True
        )

        fake_loss = tf.keras.losses.binary_crossentropy(
            tf.zeros_like(fake_output), fake_output, from_logits=True
        )

        return tf.reduce_mean(real_loss + fake_loss) * 0.5

    def _cycle_loss(self, real_image, cycled_image):
        """
        Calculates the consistency loss using the L1 norm.

        Args:
            real_image: Real image.
            cycled_image: Cycled image.

        Returns:
            The cycle consistency loss.
        """

        return tf.reduce_mean(tf.abs(real_image - cycled_image))

    def _identity_loss(self, real_image, same_image):
        """
        Calculates the identity loss using the L1 norm.

        Args:
            real_image: Real image.
            same_image: Image after identity mapping.

        Returns:
            The identity loss.
        """
        return tf.reduce_mean(tf.abs(real_image - same_image))

    def train_step(self, batch_data):
        if isinstance(batch_data, tuple):
            real_x = batch_data[0]
            real_y = batch_data[1]

        else:
            raise ValueError(
                "Expected tuple of two tensors, got: {}".format(batch_data)
            )

        if real_x.shape != real_y.shape:
            raise ValueError(f"Shape mismatch: {real_x.shape} vs {real_y.shape}")

        with tf.GradientTape(persistent=True) as tape:
            # Generator outputs
            fake_y = self.gen_G(real_x, training=True)
            fake_x = self.gen_F(real_y, training=True)

            # Cycle consistency
            cycled_x = self.gen_F(fake_y, training=True)
            cycled_y = self.gen_G(fake_x, training=True)

            # Identity mapping
            same_x = self.gen_F(real_x, training=True)
            same_y = self.gen_G(real_y, training=True)

            # Discriminator outputs
            disc_real_x = self.disc_X(real_x, training=True)
            disc_fake_x = self.disc_X(fake_x, training=True)
            disc_real_y = self.disc_Y(real_y, training=True)
            disc_fake_y = self.disc_Y(fake_y, training=True)

            # Generator losses
            gen_G_loss = self._generator_loss(disc_fake_y)
            gen_F_loss = self._generator_loss(disc_fake_x)

            # Cycle consistency loss
            cycle_loss = (
                self._cycle_loss(real_x, cycled_x) + self._cycle_loss(real_y, cycled_y)
            ) * self.config.lambda_cycle

            # Identity loss
            identity_loss = (
                self._identity_loss(real_x, same_x)
                + self._identity_loss(real_y, same_y)
            ) * self.config.lambda_identity

            # Total generator losses
            total_gen_G_loss = gen_G_loss + cycle_loss + identity_loss
            total_gen_F_loss = gen_F_loss + cycle_loss + identity_loss

            # Discriminator losses
            disc_X_loss = self._discriminator_loss(disc_real_x, disc_fake_x)
            disc_Y_loss = self._discriminator_loss(disc_real_y, disc_fake_y)

        # Calculate and apply gradients
        gen_G_gradients = tape.gradient(
            total_gen_G_loss, self.gen_G.trainable_variables
        )
        gen_F_gradients = tape.gradient(
            total_gen_F_loss, self.gen_F.trainable_variables
        )
        disc_X_gradients = tape.gradient(disc_X_loss, self.disc_X.trainable_variables)
        disc_Y_gradients = tape.gradient(disc_Y_loss, self.disc_Y.trainable_variables)

        # Apply gradients
        self.gen_G_optimizer.apply_gradients(
            zip(gen_G_gradients, self.gen_G.trainable_variables)
        )
        self.gen_F_optimizer.apply_gradients(
            zip(gen_F_gradients, self.gen_F.trainable_variables)
        )
        self.disc_X_optimizer.apply_gradients(
            zip(disc_X_gradients, self.disc_X.trainable_variables)
        )
        self.disc_Y_optimizer.apply_gradients(
            zip(disc_Y_gradients, self.disc_Y.trainable_variables)
        )

        # Update metrics
        self.gen_G_loss_tracker.update_state(total_gen_G_loss)
        self.gen_F_loss_tracker.update_state(total_gen_F_loss)
        self.disc_X_loss_tracker.update_state(disc_X_loss)
        self.disc_Y_loss_tracker.update_state(disc_Y_loss)
        self.cycle_loss_tracker.update_state(cycle_loss)
        self.identity_loss_tracker.update_state(identity_loss)

        return {
            "gen_G_loss": self.gen_G_loss_tracker.result(),
            "gen_F_loss": self.gen_F_loss_tracker.result(),
            "disc_X_loss": self.disc_X_loss_tracker.result(),
            "disc_Y_loss": self.disc_Y_loss_tracker.result(),
            "cycle_loss": self.cycle_loss_tracker.result(),
            "identity_loss": self.identity_loss_tracker.result(),
        }

    @property
    def metrics(self) -> list:
        """
        Returns the model's metrics.

        Returns:
            A list of metrics.
        """

        return [
            self.gen_G_loss_tracker,
            self.gen_F_loss_tracker,
            self.disc_X_loss_tracker,
            self.disc_Y_loss_tracker,
            self.cycle_loss_tracker,
            self.identity_loss_tracker,
        ]

    def get_config(self):
        return {"config": self.config.get_config()}

    @classmethod
    def from_config(cls, config):
        return cls(config=ModelConfig.from_config(config["config"]))

## Evaluator


In [15]:
from skimage.metrics import structural_similarity
from scipy import linalg

In [16]:
class CycleGANEvaluator:
    def __init__(self):
        """Initialize the CycleGAN evaluator with InceptionV3 model."""
        self.inception_model = tf.keras.applications.InceptionV3(
            include_top=False, weights="imagenet", pooling="avg"
        )

    def preprocess_images(self, images):
        """Preprocess images to the correct format."""
        # Ensure we're working with float32
        images = tf.cast(images, tf.float32)

        # Ensure pixel values are in [0, 1]
        if tf.reduce_max(images) > 1.0:
            images = images / 255.0

        return images

    def calculate_fid(self, real_images, generated_images):
        """
        Calculate Frechet Inception Distance between real and generated images.

        Args:
            real_images: Tensor of real images
            generated_images: Tensor of generated images

        Returns:
            FID score (float)
        """
        # Preprocess images
        real_images = self.preprocess_images(real_images)
        generated_images = self.preprocess_images(generated_images)

        # Resize images to InceptionV3 input size
        real_images = tf.image.resize(real_images, (299, 299))
        generated_images = tf.image.resize(generated_images, (299, 299))

        # Ensure batch processing for large datasets
        batch_size = 32

        def get_features(images):
            features_list = []
            for i in range(0, len(images), batch_size):
                batch = images[i : i + batch_size]
                features = self.inception_model.predict(batch, verbose=0)
                features_list.append(features)
            return np.concatenate(features_list, axis=0)

        # Get features
        real_features = get_features(real_images)
        gen_features = get_features(generated_images)

        # Calculate mean and covariance
        mu1, sigma1 = np.mean(real_features, axis=0), np.cov(
            real_features, rowvar=False
        )
        mu2, sigma2 = np.mean(gen_features, axis=0), np.cov(gen_features, rowvar=False)

        # Calculate FID
        ssdiff = np.sum((mu1 - mu2) ** 2.0)
        covmean = linalg.sqrtm(sigma1.dot(sigma2), disp=False)[0]

        if np.iscomplexobj(covmean):
            covmean = covmean.real

        return float(ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean))

    def calculate_ssim(self, real_images, generated_images):
        """Calculate SSIM between real and generated images."""
        real_images = self.preprocess_images(real_images)
        generated_images = self.preprocess_images(generated_images)

        ssim_scores = []
        for real, gen in zip(real_images, generated_images):
            # Convert to numpy and ensure correct shape
            real_np = real.numpy()
            gen_np = gen.numpy()

            # Handle grayscale images
            if len(real_np.shape) == 2:
                real_np = real_np[..., np.newaxis]
                gen_np = gen_np[..., np.newaxis]

            # Use a smaller window size (5x5) and explicitly specify channel_axis
            score = structural_similarity(
                real_np,
                gen_np,
                win_size=5,  # Smaller window size
                channel_axis=-1,  # Specify channel axis
                data_range=1.0,
            )
            ssim_scores.append(score)

        return float(np.mean(ssim_scores))

    def calculate_mse(self, real_images, generated_images):
        """Calculate MSE between real and generated images."""
        real_images = self.preprocess_images(real_images)
        generated_images = self.preprocess_images(generated_images)
        return float(tf.reduce_mean(tf.square(real_images - generated_images)))

    def calculate_psnr(self, real_images, generated_images):
        """Calculate PSNR between real and generated images."""
        mse = self.calculate_mse(real_images, generated_images)
        if mse == 0:
            return float("inf")
        return float(20 * np.log10(1.0 / tf.sqrt(mse)))

    def evaluate_model(self, model, test_dataset, num_samples=100):
        """
        Comprehensively evaluate the GAN model using multiple metrics.

        Args:
            model: CycleGAN model with gen_G attribute
            test_dataset: TensorFlow dataset containing test data
            num_samples: Number of samples to evaluate

        Returns:
            Dictionary containing evaluation metrics
        """
        # Fix the dataset caching warning by proper ordering of operations
        if num_samples is not None:
            test_dataset = test_dataset.take(num_samples).cache()
        else:
            test_dataset = test_dataset.cache()

        real_images = []
        generated_images = []

        try:
            # Collect real and generated images
            for batch in test_dataset:
                if isinstance(batch, tuple):
                    real_x = batch[0]
                else:
                    real_x = batch

                # Ensure 4D shape (batch_size, height, width, channels)
                if len(real_x.shape) != 4:
                    raise ValueError(f"Expected 4D tensor, got shape: {real_x.shape}")

                generated = model.gen_G(real_x, training=False)

                # Ensure consistent shapes
                if generated.shape != real_x.shape:
                    raise ValueError(
                        f"Shape mismatch: real {real_x.shape} vs generated {generated.shape}"
                    )

                real_images.append(real_x)
                generated_images.append(generated)

            # Convert to tensors
            real_images = tf.concat(real_images, axis=0)
            generated_images = tf.concat(generated_images, axis=0)

            print(f"Evaluating {len(real_images)} image pairs...")

            # Calculate all metrics
            metrics = {
                "fid": self.calculate_fid(real_images, generated_images),
                "ssim": self.calculate_ssim(real_images, generated_images),
                "mse": self.calculate_mse(real_images, generated_images),
                "psnr": self.calculate_psnr(real_images, generated_images),
            }

            return metrics

        except Exception as e:
            print("Error during evaluation:")
            print(f"Original error: {str(e)}")
            if len(real_images) > 0:
                print("\nTensor shapes in the batch where error occurred:")
                print(f"Real images shape: {real_images[-1].shape}")
                print(f"Generated images shape: {generated_images[-1].shape}")
            raise

In [17]:
from datetime import datetime

In [18]:
class MetricsVisualizer:
    def __init__(self, save_dir="./evaluation_results"):
        """
        Initialize the metrics visualizer.

        Args:
            save_dir: Directory to save the plots
        """
        self.save_dir = save_dir
        tf.io.gfile.makedirs(save_dir)

    def plot_image_comparison(self, real_images, generated_images, num_samples=5):
        """Plot a grid of real vs generated images."""
        plt.figure(figsize=(15, 6))

        # Randomly sample image pairs
        total_images = len(real_images)
        indices = np.random.choice(total_images, num_samples, replace=False)

        for idx, i in enumerate(indices):
            # Plot real image
            plt.subplot(2, num_samples, idx + 1)
            plt.imshow(self._prepare_image_for_plot(real_images[i]))
            plt.axis("off")
            if idx == 0:
                plt.title("Real Images")

            # Plot generated image
            plt.subplot(2, num_samples, num_samples + idx + 1)
            plt.imshow(self._prepare_image_for_plot(generated_images[i]))
            plt.axis("off")
            if idx == 0:
                plt.title("Generated Images")

        plt.tight_layout()
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        plt.savefig(f"{self.save_dir}/image_comparison_{timestamp}.png")
        plt.close()

    def plot_metrics_history(self, metrics_history):
        """Plot the evolution of metrics over time."""
        plt.figure(figsize=(15, 10))

        # Create subplots for each metric
        metrics = list(metrics_history[0].keys())
        num_metrics = len(metrics)
        rows = (num_metrics + 1) // 2

        for idx, metric in enumerate(metrics, 1):
            plt.subplot(rows, 2, idx)
            values = [h[metric] for h in metrics_history]
            epochs = range(1, len(values) + 1)

            plt.plot(epochs, values, "b-", marker="o")
            plt.title(f"{metric.upper()} over Time")
            plt.xlabel("Evaluation Step")
            plt.ylabel(metric.upper())
            plt.grid(True)

        plt.tight_layout()
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        plt.savefig(f"{self.save_dir}/metrics_history_{timestamp}.png")
        plt.close()

    def plot_metrics_distribution(self, real_images, generated_images, evaluator):
        """Plot the distribution of metrics across individual image pairs."""
        plt.figure(figsize=(15, 10))

        # Calculate metrics for each image pair
        ssim_scores = []
        mse_scores = []
        psnr_scores = []

        for real, gen in zip(real_images, generated_images):
            real_batch = tf.expand_dims(real, 0)
            gen_batch = tf.expand_dims(gen, 0)

            ssim = evaluator.calculate_ssim(real_batch, gen_batch)
            mse = evaluator.calculate_mse(real_batch, gen_batch)
            psnr = evaluator.calculate_psnr(real_batch, gen_batch)

            ssim_scores.append(ssim)
            mse_scores.append(mse)
            psnr_scores.append(psnr)

        # Plot distributions
        plt.subplot(2, 2, 1)
        plt.hist(ssim_scores, bins=30, color="blue", alpha=0.7)
        plt.title("SSIM Distribution")
        plt.xlabel("SSIM Score")
        plt.ylabel("Count")

        plt.subplot(2, 2, 2)
        plt.hist(mse_scores, bins=30, color="red", alpha=0.7)
        plt.title("MSE Distribution")
        plt.xlabel("MSE Score")
        plt.ylabel("Count")

        plt.subplot(2, 2, 3)
        plt.hist(psnr_scores, bins=30, color="green", alpha=0.7)
        plt.title("PSNR Distribution")
        plt.xlabel("PSNR Score")
        plt.ylabel("Count")

        # Add summary statistics
        plt.subplot(2, 2, 4)
        plt.axis("off")
        summary_text = (
            f"Summary Statistics:\n\n"
            f"SSIM:\n"
            f"  Mean: {np.mean(ssim_scores):.4f}\n"
            f"  Std: {np.std(ssim_scores):.4f}\n\n"
            f"MSE:\n"
            f"  Mean: {np.mean(mse_scores):.4f}\n"
            f"  Std: {np.std(mse_scores):.4f}\n\n"
            f"PSNR:\n"
            f"  Mean: {np.mean(psnr_scores):.4f}\n"
            f"  Std: {np.std(psnr_scores):.4f}"
        )
        plt.text(0.1, 0.1, summary_text, fontsize=10, fontfamily="monospace")

        plt.tight_layout()
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        plt.savefig(f"{self.save_dir}/metrics_distribution_{timestamp}.png")
        plt.close()

    def _prepare_image_for_plot(self, image):
        """Prepare image for plotting."""
        # Convert to numpy array if tensor
        if isinstance(image, tf.Tensor):
            image = image.numpy()

        # Ensure values are in [0, 1]
        if image.max() > 1.0:
            image = image / 255.0

        # Clip values to [0, 1]
        image = np.clip(image, 0, 1)

        return image

In [19]:
def plot_evaluation_results(
    evaluator, model, test_dataset, num_samples=100, metrics_history=None
):
    """
    Comprehensive plotting of evaluation results.

    Args:
        evaluator: CycleGANEvaluator instance
        model: CycleGAN model
        test_dataset: Test dataset
        num_samples: Number of samples to evaluate
        metrics_history: List of metrics dictionaries over time (optional)
    """
    visualizer = MetricsVisualizer()

    # Collect images and calculate metrics
    real_images = []
    generated_images = []

    # Take and cache the dataset
    dataset = test_dataset.take(num_samples).cache()

    for batch in dataset:
        if isinstance(batch, tuple):
            real_x = batch[0]
        else:
            real_x = batch

        generated = model.gen_G(real_x, training=False)
        real_images.extend(tf.unstack(real_x))
        generated_images.extend(tf.unstack(generated))

    # Convert to tensors
    real_images = tf.stack(real_images)
    generated_images = tf.stack(generated_images)

    # Plot image comparisons
    visualizer.plot_image_comparison(real_images, generated_images)

    # Plot metrics distribution
    visualizer.plot_metrics_distribution(real_images, generated_images, evaluator)

    # Plot metrics history if provided
    if metrics_history is not None:
        visualizer.plot_metrics_history(metrics_history)

    print(f"Plots have been saved to {visualizer.save_dir}")

### Model Visualization


In [20]:
from typing import Union

In [21]:
class CycleGANVisualizer:
    """Visualization tools for CycleGAN results."""

    def __init__(self, model: "CycleGAN"):
        self.model = model
        self.utils = VisualizationUtils()

    def display_samples(
        self,
        dataset: tf.data.Dataset,
        num_samples: Tuple[int, int] = (4, 6),
        figsize: Optional[Tuple[int, int]] = None,
        save_path: Optional[str] = None,
    ) -> None:
        """
        Display or save a grid of samples from the dataset.

        Args:
            dataset: TensorFlow dataset containing the images
            num_samples: Tuple of (rows, columns) for the display grid
            figsize: Optional tuple for figure size
            save_path: Optional path to save the figure
        """
        rows, cols = num_samples
        fig, axes = self.utils.create_figure(rows, cols, figsize)

        samples = next(iter(dataset.take(1).batch(rows * cols)))

        for idx, (ax, img) in enumerate(zip(axes, samples)):
            if isinstance(img, tuple):
                img = img[0]

            # Display image
            ax.imshow(self.utils.denormalize_image(img))
            ax.axis("off")
            ax.set_title(f"Sample {idx + 1}")

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, bbox_inches="tight", dpi=300)
            plt.close()
        else:
            plt.show()

    def display_generated_samples(
        self,
        test_dataset: tf.data.Dataset,
        num_samples: int = 3,
        figsize: Optional[Tuple[int, int]] = None,
        save_path: Optional[str] = None,
    ) -> None:
        """
        Display or save pairs of input and generated images.

        Args:
            test_dataset: Dataset containing test images
            num_samples: Number of sample pairs to display
            figsize: Optional tuple for figure size
            save_path: Optional path to save the figure
        """
        if figsize is None:
            figsize = (8 * 2, 4 * num_samples)

        fig, axes = plt.subplots(num_samples, 2, figsize=figsize)
        if num_samples == 1:
            axes = axes.reshape(1, -1)

        for idx, samples in enumerate(test_dataset.take(num_samples)):
            if isinstance(samples, tuple):
                input_image = samples[0]
            else:
                input_image = samples

            # Generate image
            generated_image = self.model.gen_G(input_image, training=False)

            axes[idx, 0].imshow(self.utils.denormalize_image(input_image[0]))
            axes[idx, 0].axis("off")
            axes[idx, 0].set_title("Input Photo")

            axes[idx, 1].imshow(self.utils.denormalize_image(generated_image[0]))
            axes[idx, 1].axis("off")
            axes[idx, 1].set_title("Generated Monet")

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, bbox_inches="tight", dpi=300)
            plt.close()
        else:
            plt.show()

    def display_progress(
        self,
        test_image: Union[str, tf.Tensor],
        epoch: int,
        save_dir: str = "training_progress",
    ) -> None:
        """
        Generate and save a Monet-style image to track training progress.

        Args:
            test_image: Path to image or tensor of test image
            epoch: Current epoch number
            save_dir: Directory to save progress images
        """
        import os

        os.makedirs(save_dir, exist_ok=True)

        if isinstance(test_image, str):
            img = tf.keras.preprocessing.image.load_img(
                test_image,
                target_size=(self.model.config.height, self.model.config.width),
            )

            img = tf.keras.preprocessing.image.img_to_array(img)
            img = (img / 127.5) - 1.0
            img = tf.expand_dims(img, 0)

        else:
            img = test_image

        generated = self.model.gen_G(img, training=False)
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

        ax1.imshow(self.utils.denormalize_image(img[0]))
        ax1.axis("off")
        ax1.set_title("Original Photo")

        ax2.imshow(self.utils.denormalize_image(generated[0]))
        ax2.axis("off")
        ax2.set_title(f"Generated Monet (Epoch {epoch})")

        plt.tight_layout()
        plt.savefig(
            os.path.join(save_dir, f"progress_epoch_{epoch}.png"),
            bbox_inches="tight",
            dpi=300,
        )

        plt.close()

## Main Execution


In [22]:
def setup_training(base_dir: str = ".", batch_size: int = 1):
    """
    Setups the training data and configuration.

    Args:
        base_dir (str): The base directory containing the data.
        batch_size (int): The batch size.

    Returns:
        A tuple containing the configuration, training dataset, test dataset, and steps per epoch.
    """

    config = ModelConfig(
        height=256,
        width=256,
        channels=3,
        base_filters=64,
        lambda_cycle=10.0,
        lambda_identity=0.5,
        learning_rate=2e-4,
        beta_1=0.5,
    )

    data_dir = Path(base_dir) / "data"
    monet_files = tf.io.gfile.glob(str(data_dir / "monet_tfrec" / "*.tfrec"))
    photo_files = tf.io.gfile.glob(str(data_dir / "photo_tfrec" / "*.tfrec"))

    if not monet_files or not photo_files:
        raise ValueError("No TFRecord files found in ", data_dir)

    n_monet = sum(1 for f in monet_files for _ in tf.data.TFRecordDataset(f))
    n_photo = sum(1 for f in photo_files for _ in tf.data.TFRecordDataset(f))
    print(f"Found {n_monet} Monet images and {n_photo} photos")

    processor = ImageProcessor(config)
    monet_ds = processor.create_dataset(
        monet_files, batch_size=batch_size, shuffle=True, cache=True
    )

    photo_ds = processor.create_dataset(
        photo_files, batch_size=batch_size, shuffle=True, cache=True
    )

    test_ds = processor.create_dataset(
        photo_files[:10],
        batch_size=1,
        shuffle=False,
        cache=True,
    )

    train_ds = tf.data.Dataset.zip((photo_ds, monet_ds))
    steps_per_epoch = min(n_monet, n_photo) // batch_size

    return config, train_ds, test_ds, steps_per_epoch

In [23]:
from datetime import datetime

In [24]:
def setup_model_and_training(config: ModelConfig, strategy):
    """
    Setups the CycleGAN model and training configuration.

    Args:
        config (ModelConfig): The model configuration.
        strategy: The distribution strategy.

    Returns:
        A tuple containing the model and training callbacks.
    """

    Path("logs/cyclegan").mkdir(parents=True, exist_ok=True)
    Path("checkpoints/cyclegan").mkdir(parents=True, exist_ok=True)

    with strategy.scope():
        model = CycleGAN(config)
        model.compile()

    callbacks = [
        tf.keras.callbacks.TensorBoard(
            log_dir=f"logs/cyclegan/{datetime.now().strftime('%Y%m%d-%H%M%S')}",
        ),
        tf.keras.callbacks.ModelCheckpoint(
            filepath="checkpoints/cyclegan/model.{epoch:03d}.weights.h5",
            save_freq="epoch",
            save_weights_only=True,
            monitor="gen_G_loss",
            mode="min",
            save_best_only=True,
        ),
    ]

    return model, callbacks

In [25]:
EPOCHS = 100
BATCH_SIZE = 1

In [26]:
config, train_ds, test_ds, steps_per_epoch = setup_training(
    base_dir="../", batch_size=BATCH_SIZE
)

2024-11-14 09:55:20.567020: I tensorflow/core/kernels/data/tf_record_dataset_op.cc:370] TFRecordDataset `buffer_size` is unspecified, default to 262144
2024-11-14 09:55:20.576787: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-11-14 09:55:20.595942: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-11-14 09:55:20.632646: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-11-14 09:55:20.754711: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-11-14 09:55:21.024572: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Found 300 Monet images and 7038 photos


In [27]:
model, callbacks = setup_model_and_training(config, strategy)

In [28]:
history = model.fit(
    train_ds,
    epochs=2,
    steps_per_epoch=steps_per_epoch,
    callbacks=callbacks,
)



Epoch 1/2


E0000 00:00:1731599751.726755    3300 meta_optimizer.cc:966] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape inStatefulPartitionedCall/generator_G_5/upsample_block_1/dropout_1/stateless_dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer
I0000 00:00:1731599755.857018   13866 cuda_dnn.cc:529] Loaded cuDNN version 90300


[1m295/300[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m1s[0m 361ms/step - cycle_loss: 7.0293 - disc_X_loss: 0.6110 - disc_Y_loss: 0.6235 - gen_F_loss: 8.1650 - gen_G_loss: 8.1606 - identity_loss: 0.3535

2024-11-14 09:57:49.789822: W tensorflow/core/kernels/data/cache_dataset_ops.cc:914] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


[1m299/300[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 361ms/step - cycle_loss: 7.0133 - disc_X_loss: 0.6107 - disc_Y_loss: 0.6237 - gen_F_loss: 8.1489 - gen_G_loss: 8.1436 - identity_loss: 0.3526

  self.gen.throw(typ, value, traceback)


[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m150s[0m 367ms/step - cycle_loss: 7.0093 - disc_X_loss: 0.6106 - disc_Y_loss: 0.6238 - gen_F_loss: 8.1448 - gen_G_loss: 8.1394 - identity_loss: 0.3524
Epoch 2/2
[1m296/300[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m1s[0m 408ms/step - cycle_loss: 5.2181 - disc_X_loss: 0.4439 - disc_Y_loss: 0.6240 - gen_F_loss: 6.6434 - gen_G_loss: 6.2898 - identity_loss: 0.2529

2024-11-14 09:59:54.620337: W tensorflow/core/kernels/data/cache_dataset_ops.cc:914] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m125s[0m 414ms/step - cycle_loss: 5.2159 - disc_X_loss: 0.4439 - disc_Y_loss: 0.6237 - gen_F_loss: 6.6423 - gen_G_loss: 6.2883 - identity_loss: 0.2528


In [29]:
visualizer = CycleGANVisualizer(model)

In [30]:
visualizer.display_generated_samples(
    test_ds, num_samples=8, save_path="generated_samples.png"
)

2024-11-14 09:59:58.765759: W tensorflow/core/kernels/data/cache_dataset_ops.cc:914] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


In [31]:
evaluator = CycleGANEvaluator()

In [32]:
metrics_history = []

In [33]:
try:
    metrics = evaluator.evaluate_model(model, test_ds, num_samples=100)
    metrics_history.append(metrics)

    plot_evaluation_results(
        evaluator=evaluator,
        model=model,
        test_dataset=test_ds,
        num_samples=100,
        metrics_history=metrics_history,
    )

    for metric_name, value in metrics.items():
        print(f"{metric_name.capitalize()}: {value:.4f}")

except Exception as e:
    print(f"Error during evaluation: {str(e)}")

2024-11-14 10:00:07.033483: W tensorflow/core/kernels/data/cache_dataset_ops.cc:914] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-11-14 10:00:07.067546: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Evaluating 100 image pairs...


I0000 00:00:1731600008.442807   13865 service.cc:148] XLA service 0x7facdc219d30 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1731600008.443264   13865 service.cc:156]   StreamExecutor device (0): NVIDIA GeForce RTX 3060 Laptop GPU, Compute Capability 8.6
2024-11-14 10:00:08.569956: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2024-11-14 10:00:11.430597: W external/local_xla/xla/tsl/framework/bfc_allocator.cc:306] Allocator (GPU_0_bfc) ran out of memory trying to allocate 593.83MiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
I0000 00:00:1731600019.960543   13865 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
2024-11-14 10:00:42.081051: W t

Plots have been saved to ./evaluation_results
Fid: 117.6134
Ssim: 0.3047
Mse: 0.0879
Psnr: 10.5605


In [34]:
model.save("cyclegan_model.keras")

In [35]:
model.save("cyclegan_model.h5")



In [36]:
generator_g = model.gen_G
generator_g.save("monet_generator.h5")



In [37]:
@tf.function(
    input_signature=[tf.TensorSpec(shape=[None, 256, 256, 3], dtype=tf.float32)]
)
def serve(input_img):
    return generator_g(input_img)


# Save with signatures
tf.saved_model.save(
    generator_g, "generator_saved_model", signatures={"serving_default": serve}
)

INFO:tensorflow:Assets written to: generator_saved_model/assets


INFO:tensorflow:Assets written to: generator_saved_model/assets


In [41]:
generator_g.save("monet_generator.keras")
generator_g.save("monet_generator.h5")



In [43]:
tf.saved_model.save(generator_g, "monet_generator/saved_model")

INFO:tensorflow:Assets written to: monet_generator/saved_model/assets


INFO:tensorflow:Assets written to: monet_generator/saved_model/assets
