In [1]:
import keras
import tensorflow as tf
from keras import layers


2025-01-30 13:20:56.298372: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-01-30 13:20:56.309361: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1738214456.322997   47002 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1738214456.327217   47002 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-30 13:20:56.338871: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [2]:
IMAGE_SHAPE = (28, 28, 1)
BATCH_SIZE = 512

noise_dim = 128
from dataset.loader import SPOTS10
train_images, train_labels, test_images, test_labels = SPOTS10()
print(f"Number of examples: {len(train_images)}")
print(f"Shape of the images: {train_images.shape[1:]}")

# Crop the images to 28x28
train_images = train_images[:, 2:30, 2:30]
# Reshape the samples to (28, 28, 1)
train_images = train_images.reshape(train_images.shape[0], *IMAGE_SHAPE).astype("float32")
train_images = (train_images - 127.5) / 127.5

File dataset/test-images-idx3-ubyte.gz already exists, skipping download
File dataset/test-labels-idx1-ubyte.gz already exists, skipping download
File dataset/train-images-idx3-ubyte.gz already exists, skipping download
File dataset/train-labels-idx1-ubyte.gz already exists, skipping download
File utilities/spots_10_loader.py already exists, skipping download
All files downloaded successfully
Number of examples: 40000
Shape of the images: (32, 32)


In [3]:
d_model = keras.Sequential(
    [
        layers.Input(shape=IMAGE_SHAPE), # (28, 28, 1)
        layers.ZeroPadding2D(padding=(2, 2)), # (32, 32, 1)
        layers.Conv2D(64, kernel_size=(5, 5), strides=(2, 2), padding="same"), # (16, 16, 64)
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(128, kernel_size=(5, 5), strides=(2, 2), padding="same"), # (8, 8, 128)
        layers.LeakyReLU(negative_slope=0.2),
        layers.Dropout(0.3),
        layers.Conv2D(256, kernel_size=(5, 5), strides=(2, 2), padding="same"), # (4, 4, 256)
        layers.LeakyReLU(negative_slope=0.2),
        layers.Dropout(0.3),
        layers.Conv2D(512, kernel_size=(5, 5), strides=(2, 2), padding="same"), # (2, 2, 512)
        layers.LeakyReLU(negative_slope=0.2),
        layers.Flatten(),
        layers.Dropout(0.2),
        layers.Dense(1), # (1,)
    ],
    name="discriminator",
)
d_model.summary()

I0000 00:00:1738214458.584083   47002 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 6156 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3070, pci bus id: 0000:02:00.0, compute capability: 8.6


In [4]:
g_model = keras.Sequential(
    [
        layers.Input(shape=(noise_dim,)), # (128,)
        layers.Dense(4 * 4 * 256, use_bias=False), # (4096,)
        layers.BatchNormalization(),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Reshape((4, 4, 256)), # (4, 4, 256)
        layers.UpSampling2D(size=(2, 2)), # (8, 8, 256)
        layers.Conv2D(128, kernel_size=(3, 3), strides=(1, 1), padding="same", use_bias=False), # (8, 8, 128)
        layers.BatchNormalization(),
        layers.LeakyReLU(negative_slope=0.2),
        layers.UpSampling2D(size=(2, 2)), # (16, 16, 128)
        layers.Conv2D(64, kernel_size=(3, 3), strides=(1, 1), padding="same", use_bias=False), # (16, 16, 64)
        layers.BatchNormalization(),
        layers.LeakyReLU(negative_slope=0.2),
        layers.UpSampling2D(size=(2, 2)), # (32, 32, 64)
        layers.Conv2D(1, kernel_size=(3, 3), strides=(1, 1), padding="same", use_bias=False), # (32, 32, 1)
        layers.BatchNormalization(),
        layers.Activation("tanh"),
        layers.Cropping2D(cropping=(2, 2)), # (28, 28, 1)
    ],
    name="generator",
)

g_model.summary()

In [5]:
class WGAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim, discriminator_extra_steps=3, gp_weight=10.0):
        super(WGAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_steps = discriminator_extra_steps
        self.gp_weight = gp_weight

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        super(WGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn
    
    def gradient_penalty(self, batch_size, real_images, fake_images):
        """Calculates the gradient penalty.

        This loss is calculated on an interpolated image
        and added to the discriminator loss.
        """
        # Get the interpolated image
        alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        diff = fake_images - real_images # (batch_size, 28, 28, 1)
        interpolated = real_images + alpha * diff # (batch_size, 28, 28, 1) after broadcasting

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            # 1. Get the discriminator output for this interpolated image.
            pred = self.discriminator(interpolated, training=True)

        # 2. Calculate the gradients w.r.t to this interpolated image.
        grads = gp_tape.gradient(pred, [interpolated])[0]
        # 3. Calculate the L2 norm of the gradients.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3])) # (batch_size,)
        gp = tf.reduce_mean((norm - 1.0) ** 2) # scalar
        return gp
    
    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        # Get the batch size
        batch_size = tf.shape(real_images)[0]

        # For each batch, we are going to perform the
        # following steps as laid out in the original paper:
        # 1. Train the generator and get the generator loss
        # 2. Train the discriminator and get the discriminator loss
        # 3. Calculate the gradient penalty
        # 4. Multiply this gradient penalty with a constant weight factor
        # 5. Add the gradient penalty to the discriminator loss
        # 6. Return the generator and discriminator losses as a loss dictionary

        # Train the discriminator first. The original paper recommends training
        # the discriminator for `x` more steps (typically 5) as compared to
        # one step of the generator. Here we will train it for 3 extra steps
        # as compared to 5 to reduce the training time.

        for i in range(self.d_steps):
            # Get the latent vector
            random_latent_vectors = tf.random.normal(
                shape=(batch_size, self.latent_dim)
            ) # (batch_size, 128)
            with tf.GradientTape() as tape:
                # Generate fake images from the latent vector
                fake_images = self.generator(random_latent_vectors, training=True)
                # Get the logits for the fake images
                fake_logits = self.discriminator(fake_images, training=True)
                # Get the logits for the real images
                real_logits = self.discriminator(real_images, training=True)

                # Calculate the discriminator loss using the fake and real image logits
                d_cost = self.d_loss_fn(real_img=real_logits, fake_img=fake_logits)
                # Calculate the gradient penalty
                gp = self.gradient_penalty(batch_size, real_images, fake_images)
                # Add the gradient penalty to the original discriminator loss
                d_loss = d_cost + gp * self.gp_weight

            # Get the gradients w.r.t the discriminator loss
            d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
            # Update the weights of the discriminator using the discriminator optimizer
            self.d_optimizer.apply_gradients(
                zip(d_gradient, self.discriminator.trainable_variables)
            )

        # Train the generator
        # Get the latent vector
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            # Generate fake images using the generator
            generated_images = self.generator(random_latent_vectors, training=True)
            # Get the discriminator logits for fake images
            gen_img_logits = self.discriminator(generated_images, training=True)
            # Calculate the generator loss
            g_loss = self.g_loss_fn(gen_img_logits)
        
        # Get the gradients w.r.t the generator loss
        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        # Update the weights of the generator using the generator optimizer
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )
        return {"d_loss": d_loss, "g_loss": g_loss}

In [6]:
class GANMonitor(keras.callbacks.Callback):
    def __init__(self, num_img=6, latent_dim=128):
        self.num_img = num_img
        self.latent_dim = latent_dim

    def on_epoch_end(self, epoch, logs=None):
        random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))
        generated_images = self.model.generator(random_latent_vectors)
        generated_images = (generated_images * 127.5) + 127.5

        for i in range(self.num_img):
            img = generated_images[i].numpy()
            img = keras.utils.array_to_img(img)
            img.save("./generated_spots10/generated_img_{i}_{epoch}.png".format(i=i, epoch=epoch))


In [7]:
# Instantiate the optimizer for both networks
# (learning_rate=0.0002, beta_1=0.5 are recommended)
generator_optimizer = keras.optimizers.Adam(
    learning_rate=0.0002, beta_1=0.5, beta_2=0.9
)
discriminator_optimizer = keras.optimizers.Adam(
    learning_rate=0.0002, beta_1=0.5, beta_2=0.9
)


# Define the loss functions for the discriminator,
# which should be (fake_loss - real_loss).
# We will add the gradient penalty later to this loss function.
def discriminator_loss(real_img, fake_img):
    real_loss = tf.reduce_mean(real_img)
    fake_loss = tf.reduce_mean(fake_img)
    return fake_loss - real_loss


# Define the loss functions for the generator.
def generator_loss(fake_img):
    return -tf.reduce_mean(fake_img)


# Set the number of epochs for training.
epochs = 20

# Instantiate the customer `GANMonitor` Keras callback.
cbk = GANMonitor(num_img=3, latent_dim=noise_dim)

# Get the wgan model
wgan = WGAN(
    discriminator=d_model,
    generator=g_model,
    latent_dim=noise_dim,
    discriminator_extra_steps=3,
)

# Compile the wgan model
wgan.compile(
    d_optimizer=discriminator_optimizer,
    g_optimizer=generator_optimizer,
    g_loss_fn=generator_loss,
    d_loss_fn=discriminator_loss,
)

# Start training
wgan.fit(train_images, batch_size=BATCH_SIZE, epochs=epochs, callbacks=[cbk])

Epoch 1/20


I0000 00:00:1738214465.533708   47150 service.cc:148] XLA service 0x780c0c015580 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1738214465.533767   47150 service.cc:156]   StreamExecutor device (0): NVIDIA GeForce RTX 3070, Compute Capability 8.6
2025-01-30 13:21:05.792725: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1738214466.730158   47150 cuda_dnn.cc:529] Loaded cuDNN version 90300

I0000 00:00:1738214487.626611   47150 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m63s[0m 437ms/step - d_loss: -1.7008 - g_loss: 6.3057
Epoch 2/20
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 218ms/step - d_loss: -1.9183 - g_loss: -6.4690
Epoch 3/20
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 220ms/step - d_loss: -1.4973 - g_loss: -7.2069
Epoch 4/20
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 220ms/step - d_loss: -1.2894 - g_loss: -4.3705
Epoch 5/20
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 221ms/step - d_loss: -1.3693 - g_loss: -5.7176
Epoch 6/20
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 221ms/step - d_loss: -1.5418 - g_loss: -3.6247
Epoch 7/20
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 221ms/step - d_loss: -1.6098 - g_loss: -2.1212
Epoch 8/20
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 222ms/step - d_loss: -1.4846 - g_loss: -1.7390
Epoch 9/20
[1m79/79[0m [3

<keras.src.callbacks.history.History at 0x780d1c0e59d0>

In [8]:
wgan.fit(train_images, batch_size=BATCH_SIZE, epochs=epochs, callbacks=[cbk])

Epoch 1/20
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 218ms/step - d_loss: -1.4257 - g_loss: -1.9087
Epoch 2/20
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 221ms/step - d_loss: -1.4199 - g_loss: -1.4415
Epoch 3/20
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 222ms/step - d_loss: -1.4140 - g_loss: -1.0445
Epoch 4/20
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 224ms/step - d_loss: -1.4049 - g_loss: -1.0742
Epoch 5/20
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 223ms/step - d_loss: -1.3969 - g_loss: -0.9166
Epoch 6/20
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 223ms/step - d_loss: -1.3697 - g_loss: -0.7089
Epoch 7/20
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 222ms/step - d_loss: -1.3551 - g_loss: -0.5464
Epoch 8/20
[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 223ms/step - d_loss: -1.3455 - g_loss: -0.3322
Epoch 9/20
[1m7

<keras.src.callbacks.history.History at 0x780cac517950>