In [None]:
import os

import keras
from keras import layers
from keras import ops
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt

from load_data import load_data

In [None]:
# Model / data dimensions. Defined before loading so the one-hot encoding
# below can be pinned to `num_classes`.
num_channels = 3
num_classes = 45
latent_dim = 128
image_size = 256

# Conditioning appends one entry per class: extra image channels for the
# discriminator, extra input features for the generator.
discriminator_in_channels = num_channels + num_classes
generator_in_channels = latent_dim + num_classes

# Location passed to `load_data`. Set the TFDS_PACKAGE_PATH environment
# variable to run on a machine other than the original author's.
package_path = os.environ.get(
    "TFDS_PACKAGE_PATH",
    r"C:\Users\nedst\Desktop\synoptic-project-NedStickler\.venv\Lib\site-packages\tensorflow_datasets",
)
dataset, labels = load_data(package_path)
# Pass num_classes explicitly. Without it, to_categorical infers the width
# from the largest label present, so a missing class would produce one-hot
# vectors narrower than the models' input layers expect.
processed_labels = keras.utils.to_categorical(labels, num_classes=num_classes)

In [None]:
# The generator starts from a square feature map and applies three stride-2
# transposed convolutions, upsampling by 2**3 = 8 to reach `image_size`.
assert image_size % 8 == 0, "image_size must be divisible by 8 for the generator"
generator_base_size = image_size // 8

# Create the discriminator.
# Input: an image with the class label broadcast over every pixel as extra
# channels (num_channels + num_classes). Output: one unactivated score (logit).
discriminator = keras.Sequential(
    [
        keras.layers.InputLayer((image_size, image_size, discriminator_in_channels)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name="discriminator",
)

# Create the generator.
# Input: latent noise concatenated with a one-hot label.
# Output: an (image_size, image_size, num_channels) image in [0, 1] (sigmoid).
generator = keras.Sequential(
    [
        keras.layers.InputLayer((generator_in_channels,)),
        layers.Dense(generator_base_size * generator_base_size * generator_in_channels),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Reshape((generator_base_size, generator_base_size, generator_in_channels)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(num_channels, (8, 8), padding="same", activation="sigmoid"),
    ],
    name="generator",
)

In [None]:
class ConditionalGAN(keras.Model):
    """Conditional GAN that alternately trains a discriminator and a generator.

    The discriminator is trained to output 1 for generated images and 0 for
    real images. The generator is trained so that the discriminator outputs 0
    ("real") for its images. Both networks see the class label: the generator
    as extra input features, the discriminator as extra image channels.

    Args:
        discriminator: model taking images with label channels appended and
            returning one score per image.
        generator: model taking latent noise with the label appended and
            returning an image.
        latent_dim: length of the random noise vector fed to the generator.
    """

    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.seed_generator = keras.random.SeedGenerator(1337)
        self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")

    @property
    def metrics(self):
        # Exposing the trackers here lets Keras reset them between epochs.
        return [self.gen_loss_tracker, self.disc_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Store one optimizer per sub-model plus the shared loss function."""
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    @staticmethod
    def _label_feature_maps(labels, height, width):
        """Broadcast (batch, classes) labels to (batch, height, width, classes).

        Every pixel receives an identical copy of its image's label vector, so
        the result can be concatenated with the image along the channel axis.
        """
        # NOTE: ops.repeat without `axis` flattens its input, which scrambles
        # label values across pixels; tiling keeps each label vector intact.
        return ops.tile(labels[:, None, None, :], (1, height, width, 1))

    def _sample_generator_inputs(self, labels, batch_size):
        """Draw fresh latent noise and append the class labels to each row."""
        latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )
        return ops.concatenate([latent_vectors, labels], axis=1)

    def train_step(self, data):
        """Run one discriminator update followed by one generator update.

        Args:
            data: ``(real_images, labels)`` for one batch. Images are divided
                by 255 here, so raw 0-255 pixel values are expected. Labels are
                one-hot rows (presumably of width ``num_classes``).

        Returns:
            Dict with the running-mean generator and discriminator losses.
        """
        real_images, labels = data
        real_images = ops.cast(real_images, "float32") / 255.0
        labels = ops.cast(labels, "float32")

        # Use ops.shape: the static batch dimension is None while Keras traces
        # train_step inside tf.function.
        image_shape = ops.shape(real_images)
        batch_size, height, width = image_shape[0], image_shape[1], image_shape[2]

        label_maps = self._label_feature_maps(labels, height, width)

        # ---- Discriminator update ----
        generated_images = self.generator(self._sample_generator_inputs(labels, batch_size))
        combined_images = ops.concatenate(
            [
                ops.concatenate([generated_images, label_maps], axis=-1),
                ops.concatenate([real_images, label_maps], axis=-1),
            ],
            axis=0,
        )
        # Targets line up with combined_images: generated first (1), then real (0).
        discriminator_labels = ops.concatenate(
            [ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0
        )

        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(discriminator_labels, predictions)
        gradients = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(zip(gradients, self.discriminator.trainable_weights))

        # ---- Generator update ----
        generator_inputs = self._sample_generator_inputs(labels, batch_size)
        # Target "real" (0) so the generator learns to fool the discriminator.
        misleading_labels = ops.zeros((batch_size, 1))

        with tf.GradientTape() as tape:
            generated_images = self.generator(generator_inputs)
            predictions = self.discriminator(
                ops.concatenate([generated_images, label_maps], axis=-1)
            )
            g_loss = self.loss_fn(misleading_labels, predictions)
        gradients = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(gradients, self.generator.trainable_weights))

        # Update metrics
        self.gen_loss_tracker.update_state(g_loss)
        self.disc_loss_tracker.update_state(d_loss)

        return {"g_loss": self.gen_loss_tracker.result(), "d_loss": self.disc_loss_tracker.result()}

In [None]:
# Training settings. Only the first `n_train_samples` examples are used to
# keep each epoch short.
n_train_samples = 2048
n_epochs = 5
d_learning_rate = 0.0003
g_learning_rate = 0.0003

cond_gan = ConditionalGAN(
    discriminator=discriminator, generator=generator, latent_dim=latent_dim
)
cond_gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=d_learning_rate),
    g_optimizer=keras.optimizers.Adam(learning_rate=g_learning_rate),
    # The discriminator ends in Dense(1) with no activation, so it emits logits.
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

# Keep the History object (per-epoch g_loss / d_loss) rather than echoing its repr.
history = cond_gan.fit(
    dataset[:n_train_samples],
    processed_labels[:n_train_samples],
    epochs=n_epochs,
)

In [None]:
# Class index (0-based, must be < num_classes) to generate an example of.
target_class = 43

# One-hot label of shape (1, num_classes). Use float32 to match the latent
# noise so the concatenation is not promoted to float64.
label = np.zeros((1, num_classes), dtype="float32")
label[0, target_class] = 1
latent_vector = keras.random.normal(shape=(1, latent_dim))
latent_vector_with_label = ops.concatenate([latent_vector, label], axis=1)

generated_image = cond_gan.generator.predict(latent_vector_with_label)
# The generator ends in a sigmoid, so outputs lie in [0, 1]; rescale to 8-bit pixels.
processed_image = (generated_image[0] * 255).astype(np.uint8)
processed_image = processed_image.reshape((image_size, image_size, num_channels))

fig, ax = plt.subplots()
ax.imshow(processed_image)
ax.set_title(f"Generated image for class {target_class}")
ax.axis("off")
plt.show()