In [1]:
import os 
import yaml
import tensorflow as tf 
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go 

In [2]:
# Run from the repository root (the parent of the notebook's directory) so
# that `utils` is importable and `configs/...` paths resolve.
# The root is computed only once per kernel, so re-running this cell does not
# climb one more directory each time (a bare `os.chdir('..')` would).
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath("..")
os.chdir(PROJECT_ROOT)

from utils.GAN_utils import (build_generator, build_discriminator, GAN, 
                             generate_fake_images,
                             generate_real_and_fake_images
                             )

### Set up generator & discriminator configs

In [3]:
# A single YAML file configures both networks; it is read relative to the
# current working directory (the project root at this point).
config_path = os.path.join(os.getcwd(), "configs/GAN_config.yaml")
with open(config_path, "r") as config_file:
    gan_config = yaml.safe_load(config_file)

# Each network has its own top-level section in the YAML file.
generator_config, discriminator_config = (
    gan_config['generator'],
    gan_config['discriminator'],
)

### Load data: MNIST

In [4]:
# Keras returns ((train_images, train_labels), (test_images, test_labels));
# the archive is downloaded on first use and cached by Keras afterwards.
mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()
X_train, y_train = mnist_train
X_test, y_test = mnist_test

In [5]:
# Pool the train and test images into one array. The labels are not used
# here, so both splits are simply extra digit images to learn from.
# (Previously X_train was concatenated with itself, which duplicated every
# training image and silently dropped the test set.)
mnist_digits = np.concatenate([X_train, X_test], axis=0)
# Add a trailing channel axis and map 0-255 pixel values into [0, 1] as float32.
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255
assert mnist_digits.shape == (len(X_train) + len(X_test),) + X_train.shape[1:] + (1,)

### Build generator and discriminator

In [6]:
# Build the generator from the `generator` section of configs/GAN_config.yaml
# and print its layer table for inspection.
generator = build_generator(generator_config)
generator.summary()

Model: "generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
generator_input (InputLayer) [(None, 2)]               0         
_________________________________________________________________
shape_prod (Dense)           (None, 6272)              18816     
_________________________________________________________________
reshape (Reshape)            (None, 7, 7, 128)         0         
_________________________________________________________________
generator_conv_1 (Conv2DTran (None, 14, 14, 128)       262272    
_________________________________________________________________
batch_norm_generator_1 (Batc (None, 14, 14, 128)       512       
_________________________________________________________________
leaky_relu_generator_1 (Leak (None, 14, 14, 128)       0         
_________________________________________________________________
generator_conv_2 (Conv2DTran (None, 28, 28, 128)       26

In [8]:
# Build the discriminator from the `discriminator` section of
# configs/GAN_config.yaml and print its layer table for inspection.
discriminator = build_discriminator(discriminator_config)
discriminator.summary()

# The discriminator must accept exactly what the generator produces, otherwise
# the combined GAN model cannot be wired together. `.outputs` / `.inputs` are
# always lists, so this works whether or not the models were built with list I/O.
generated_image_shape = tuple(generator.outputs[0].shape[1:])
discriminator_input_shape = tuple(discriminator.inputs[0].shape[1:])
assert generated_image_shape == discriminator_input_shape, (
    f"generator output {generated_image_shape} does not match "
    f"discriminator input {discriminator_input_shape}"
)

Model: "discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
discriminator_input (InputLa [(None, 28, 28, 1)]       0         
_________________________________________________________________
discriminator_conv_1 (Conv2D (None, 14, 14, 64)        640       
_________________________________________________________________
batch_norm_discriminator_1 ( (None, 14, 14, 64)        256       
_________________________________________________________________
leaky_relu_discriminator_1 ( (None, 14, 14, 64)        0         
_________________________________________________________________
discriminator_conv_2 (Conv2D (None, 7, 7, 64)          36928     
_________________________________________________________________
batch_norm_discriminator_2 ( (None, 7, 7, 64)          256       
_________________________________________________________________
leaky_relu_discriminator_2 ( (None, 7, 7, 64)        

In [9]:
# Combine the generator and discriminator into a single model; its summary
# lists them as two stacked sub-models.
#
# NOTE: this cell rebinds the name `GAN` from the class (imported at the top)
# to the model instance, and later cells may rely on `GAN` being the instance.
# The class is therefore fetched under a private alias, so re-running this cell
# builds a new model instead of trying to call the existing instance.
from utils.GAN_utils import GAN as _GANModel

GAN = _GANModel(generator, discriminator)
GAN.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
generator (Functional)       (None, 28, 28, 1)         550657    
_________________________________________________________________
discriminator (Functional)   (None, 1)                 41217     
Total params: 591,874
Trainable params: 550,145
Non-trainable params: 41,729
_________________________________________________________________
