In [1]:
import keras
import tensorflow as tf
from keras import layers


2025-01-27 17:47:09.581427: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-01-27 17:47:09.591424: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1737971229.603216  204486 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1737971229.606444  204486 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-27 17:47:09.617922: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [2]:
# Shape of one sample once an explicit channel axis is added.
IMAGE_SHAPE = (28, 28, 1)
BATCH_SIZE = 512

# Length of the latent noise vector fed to the generator.
noise_dim = 128

# Seed Python, NumPy and the backend RNGs so weight initialisation, dropout
# masks and sampled noise are reproducible under Restart & Run All.
SEED = 42
keras.utils.set_random_seed(SEED)

fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
print(f"Number of examples: {len(train_images)}")
print(f"Shape of the images: {train_images.shape[1:]}")

# Add the channel axis -> (N, 28, 28, 1) and switch to float for the arithmetic below.
train_images = train_images.reshape(train_images.shape[0], *IMAGE_SHAPE).astype("float32")
# Map raw pixel values (0..255) to [-1, 1], the output range of the generator's tanh.
train_images = (train_images - 127.5) / 127.5
# NOTE(review): test_images is left unscaled; normalise it the same way if it is used later.

# Sanity-check the preprocessing before any model consumes the data.
assert train_images.shape[1:] == IMAGE_SHAPE, train_images.shape
assert train_images.min() >= -1.0 and train_images.max() <= 1.0

Number of examples: 60000
Shape of the images: (28, 28)


In [3]:
def build_discriminator():
    """Build the WGAN critic as a ``keras.Sequential`` named "discriminator".

    The (28, 28, 1) input is zero-padded to (32, 32, 1) so that four stride-2
    "same" convolutions halve it cleanly: 32 -> 16 -> 8 -> 4 -> 2. The final
    Dense(1) has no activation, so the critic emits an unbounded score.
    """
    # (filters, dropout rate after the activation, or None for no dropout)
    conv_specs = ((64, None), (128, 0.3), (256, 0.3), (512, None))

    stack = [
        layers.Input(shape=IMAGE_SHAPE),
        layers.ZeroPadding2D(padding=(2, 2)),  # -> (32, 32, 1)
    ]
    for filters, drop_rate in conv_specs:
        stack.append(layers.Conv2D(filters, kernel_size=(5, 5), strides=(2, 2), padding="same"))
        stack.append(layers.LeakyReLU(negative_slope=0.2))
        if drop_rate is not None:
            stack.append(layers.Dropout(drop_rate))

    # The last conv block leaves a (2, 2, 512) feature map.
    stack.extend([layers.Flatten(), layers.Dropout(0.2), layers.Dense(1)])
    return keras.Sequential(stack, name="discriminator")


d_model = build_discriminator()
d_model.summary()

I0000 00:00:1737971231.702136  204486 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 6156 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3070, pci bus id: 0000:02:00.0, compute capability: 8.6


In [4]:
def build_generator():
    """Build the generator as a ``keras.Sequential`` named "generator".

    A noise vector of length ``noise_dim`` is projected to a (4, 4, 256) tensor.
    Three UpSampling2D -> Conv2D -> BatchNormalization stages then grow it to
    (32, 32, 1), and a final 2-pixel crop yields (28, 28, 1). Every stage but the
    last uses LeakyReLU; the last uses tanh.
    """
    stack = [
        layers.Input(shape=(noise_dim,)),
        layers.Dense(4 * 4 * 256, use_bias=False),  # -> (4096,)
        layers.BatchNormalization(),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Reshape((4, 4, 256)),
    ]

    # Output channels per stage; the spatial size doubles each time: 8, 16, 32.
    stage_filters = (128, 64, 1)
    for idx, filters in enumerate(stage_filters):
        stack.append(layers.UpSampling2D(size=(2, 2)))
        stack.append(
            layers.Conv2D(filters, kernel_size=(3, 3), strides=(1, 1), padding="same", use_bias=False)
        )
        # NOTE(review): the output stage also applies BatchNormalization right before
        # tanh; DCGAN guidance recommends omitting BN on the generator output layer.
        stack.append(layers.BatchNormalization())
        if idx == len(stage_filters) - 1:
            stack.append(layers.Activation("tanh"))
        else:
            stack.append(layers.LeakyReLU(negative_slope=0.2))

    stack.append(layers.Cropping2D(cropping=(2, 2)))  # (32, 32, 1) -> (28, 28, 1)
    return keras.Sequential(stack, name="generator")


g_model = build_generator()

g_model.summary()

In [None]:
class WGAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim, discriminator_extra_steps=3, gp_weight=10.0):
        """Store the critic/generator pair and WGAN-GP hyper-parameters.

        Args:
            discriminator: critic model; ``gradient_penalty`` calls it on images.
            generator: model producing images; presumably fed latent vectors of
                length ``latent_dim`` (the training step is not visible here).
            latent_dim: size of the noise vector, kept as ``self.latent_dim``.
            discriminator_extra_steps: kept as ``self.d_steps``; presumably the number
                of critic updates per generator update — confirm against train_step.
            gp_weight: kept as ``self.gp_weight``; presumably the coefficient applied
                to the gradient penalty — confirm against train_step.
        """
        super().__init__()
        # Sub-models are assigned in the same order as before (critic, then generator).
        self.discriminator, self.generator = discriminator, generator
        self.latent_dim = latent_dim
        self.d_steps, self.gp_weight = discriminator_extra_steps, gp_weight

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn, **kwargs):
        """Configure the WGAN with separate optimizers and losses per sub-model.

        Args:
            d_optimizer: optimizer for the discriminator, kept as ``self.d_optimizer``.
            g_optimizer: optimizer for the generator, kept as ``self.g_optimizer``.
            d_loss_fn: discriminator loss callable, kept as ``self.d_loss_fn``.
            g_loss_fn: generator loss callable, kept as ``self.g_loss_fn``.
            **kwargs: forwarded unchanged to ``keras.Model.compile``. Without this,
                standard options such as ``run_eagerly`` or ``jit_compile`` could not
                be set on this model. Calling without extra kwargs behaves as before.
        """
        super().compile(**kwargs)
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn
    
    def gradient_penalty(self, batch_size, real_images, fake_images):
        """Calculate the WGAN-GP gradient penalty.

        Samples one random point per example on the line between a real and a
        fake image and evaluates the discriminator there. The penalty is the mean
        squared deviation of the per-example gradient L2 norm from 1:
        mean((||grad||_2 - 1)^2). The result is unweighted; the caller is expected
        to scale it (presumably by ``self.gp_weight``) and add it to the
        discriminator loss.

        Args:
            batch_size: leading dimension of the per-example interpolation
                coefficients; should equal the batch dimension of the images.
            real_images: batch of real images, assumed rank 4 (batch, H, W, C),
                since the norm below reduces over axes 1, 2, 3.
            fake_images: generated images with the same shape as ``real_images``.

        Returns:
            Scalar tensor holding the (unweighted) gradient penalty.
        """
        # One coefficient per example, broadcast over H, W, C. It follows the image
        # dtype instead of always being float32.
        alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0, dtype=real_images.dtype)
        interpolated = real_images + alpha * (fake_images - real_images)

        with tf.GradientTape() as gp_tape:
            # `interpolated` is a plain tensor, not a Variable, so it must be watched explicitly.
            gp_tape.watch(interpolated)
            pred = self.discriminator(interpolated, training=True)

        # Gradient of the critic output w.r.t. the interpolated images.
        grads = gp_tape.gradient(pred, [interpolated])[0]
        # Per-example L2 norm over H, W, C. The epsilon keeps sqrt's argument off 0:
        # d/dx sqrt(x) is infinite at 0, which would turn an all-zero gradient into NaN.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
        return tf.reduce_mean((norm - 1.0) ** 2)