In [None]:
import os
from pathlib import Path

import keras
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

# `src` presumably lives in the project root, one level above this notebook's
# directory. Only step up when `src` is not already visible from the current
# working directory, so re-running this cell does not keep walking further up
# the directory tree (the unguarded chdir was not idempotent).
if not (Path.cwd() / "src").is_dir():
    os.chdir(Path.cwd().parent)
from src.load_data import load_data

In [None]:
import sysconfig

# Locate the installed tensorflow_datasets package inside the running
# interpreter's own site-packages instead of hard-coding a user-specific
# absolute path (the original pointed at one machine's .venv). Set the
# TFDS_PACKAGE_PATH environment variable to override the location.
package_path = os.environ.get(
    "TFDS_PACKAGE_PATH",
    str(Path(sysconfig.get_paths()["purelib"]) / "tensorflow_datasets"),
)
dataset = load_data(package_path)

In [None]:
# Shapes shared by both networks: (height, width, channels) images, and the
# length of the noise vector fed to the generator.
input_shape = (256, 256, 3)
latent_dim = 128

# The generator upsamples three times by a factor of 2 (three stride-2
# Conv2DTranspose layers), so it starts from a feature map 2**3 = 8x smaller
# than the target image in each spatial dimension. Deriving these sizes from
# `input_shape` keeps the generator's output in sync with the discriminator's
# input if the image size is changed.
upsample_factor = 2 ** 3
if input_shape[0] % upsample_factor or input_shape[1] % upsample_factor:
    raise ValueError(
        f"input_shape spatial dims must be divisible by {upsample_factor}, got {input_shape}"
    )
base_height = input_shape[0] // upsample_factor
base_width = input_shape[1] // upsample_factor
base_channels = 128

# Discriminator: three stride-2 convolutions, then a single sigmoid score per image.
discriminator = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        keras.layers.Conv2D(64, kernel_size=4, strides=2, padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Flatten(),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(1, activation="sigmoid"),
    ],
    name="discriminator",
)

# Generator: project the latent vector to a small feature map, upsample it back
# to `input_shape`, and emit per-pixel values in [0, 1] via a sigmoid.
generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        keras.layers.Dense(base_height * base_width * base_channels),
        keras.layers.Reshape((base_height, base_width, base_channels)),
        keras.layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Conv2D(input_shape[2], kernel_size=4, activation="sigmoid", padding="same"),
    ],
    name="generator",
)

In [None]:
class GAN(keras.Model):
    """Wraps a generator and a discriminator and trains them adversarially.

    Label convention used in `train_step`: generated images are labelled 1 and
    real images 0, so the discriminator's output is a score for "generated".
    """

    def __init__(self, discriminator, generator, latent_dim):
        """
        Args:
            discriminator: model mapping an image batch to one score per image.
            generator: model mapping a (batch, latent_dim) noise batch to images.
            latent_dim: length of the latent noise vectors fed to the generator.
        """
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        # Fixed seed so the sequence of sampled latent vectors is reproducible.
        self.seed_generator = keras.random.SeedGenerator(1337)

    def compile(self, d_optimiser, g_optimiser, loss_fn):
        """Attach one optimiser per sub-network plus the shared loss.

        Args:
            d_optimiser: optimiser applied to the discriminator's weights.
            g_optimiser: optimiser applied to the generator's weights.
            loss_fn: loss called as `loss_fn(labels, predictions)`.
        """
        super().compile()
        self.d_optimiser = d_optimiser
        self.g_optimiser = g_optimiser
        self.loss_fn = loss_fn
        self.d_loss_metric = keras.metrics.Mean(name="d_loss")
        self.g_loss_metric = keras.metrics.Mean(name="g_loss")

    @property
    def metrics(self):
        # Exposes the running-loss trackers to Keras so fit() reports and resets them.
        return [self.d_loss_metric, self.g_loss_metric]

    def train_step(self, inputs):
        """Run one discriminator update followed by one generator update.

        Args:
            inputs: batch of real images. They are scaled by 1/255 here, so they
                are presumably raw 0-255 pixel values (TODO confirm against
                `load_data`).

        Returns:
            dict with the running means of the discriminator and generator losses.
        """
        batch_size = keras.ops.shape(inputs)[0]

        # Sample latent noise and produce a batch of fake images.
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )
        generated_images = self.generator(random_latent_vectors, training=True)

        # Cast before scaling: dividing an int32/int64 tensor yields float64,
        # which would then fail to concatenate with the generator's output.
        real_images = keras.ops.cast(inputs, generated_images.dtype) / 255.0

        combined_images = keras.ops.concatenate([generated_images, real_images], axis=0)
        # Generated images -> 1, real images -> 0 (same order as combined_images).
        # TODO: consider adding small random noise to these labels, a common GAN
        # stabilisation trick; it is not applied here.
        labels = keras.ops.concatenate(
            [keras.ops.ones((batch_size, 1)), keras.ops.zeros((batch_size, 1))], axis=0
        )

        # Discriminator update. training=True so its Dropout layer is active;
        # otherwise Dropout falls back to its inference-mode default.
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images, training=True)
            d_loss = self.loss_fn(labels, predictions)
        d_gradients = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimiser.apply_gradients(zip(d_gradients, self.discriminator.trainable_weights))

        # Generator update: fresh noise labelled as "real" (0), so the generator
        # is rewarded for fooling the discriminator. Only generator weights move.
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )
        misleading_labels = keras.ops.zeros((batch_size, 1))
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_latent_vectors, training=True)
            predictions = self.discriminator(fake_images, training=True)
            g_loss = self.loss_fn(misleading_labels, predictions)
        g_gradients = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimiser.apply_gradients(zip(g_gradients, self.generator.trainable_weights))

        # Update the running loss averages reported by fit().
        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)

        return {"d_loss": self.d_loss_metric.result(), "g_loss": self.g_loss_metric.result()}

# Training configuration.
epochs = 5
steps_per_epoch = 8
batch_size = 32  # Keras' default batch size, which the original fit() call used implicitly
num_train_samples = 512

gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimiser=keras.optimizers.Adam(learning_rate=0.0001),
    g_optimiser=keras.optimizers.Adam(learning_rate=0.0001),
    loss_fn=keras.losses.BinaryCrossentropy(),
)

sample_dataset = dataset[:num_train_samples, :, :, :]
# 512 samples at batch size 32 give only 16 batches, but epochs * steps_per_epoch
# needs 40. A finite input is exhausted after two epochs and Keras interrupts
# training. Shuffling and repeating keeps 8 steps per epoch for all 5 epochs.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(sample_dataset)
    .shuffle(num_train_samples, seed=1337)
    .batch(batch_size)
    .repeat()
)
history = gan.fit(train_dataset, epochs=epochs, steps_per_epoch=steps_per_epoch)

In [None]:
# Sample a fixed (seeded) set of latent vectors and render the generator's output.
num_images = 25
grid_cols = 5
random_latent_vectors = keras.random.normal(
    shape=(num_images, gan.latent_dim), seed=keras.random.SeedGenerator(42)
)
generated_images = gan.generator(random_latent_vectors, training=False)
processed_images = generated_images.numpy()
# The generator ends in a sigmoid, so pixel values lie in [0, 1]; rescale to 8-bit.
normalised_images = (processed_images * 255).astype(np.uint8)

# Show all samples in a single grid rather than one figure per image.
grid_rows = -(-num_images // grid_cols)  # ceiling division
fig, axes = plt.subplots(
    grid_rows, grid_cols, figsize=(2 * grid_cols, 2 * grid_rows), squeeze=False
)
for index, ax in enumerate(axes.flat):
    if index < num_images:
        ax.imshow(normalised_images[index])
    ax.axis("off")
fig.suptitle("Generated samples")
plt.show()