In [1]:
import os 
import yaml
import tensorflow as tf 
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go 
from typing import Dict, List, Union

In [2]:
os.chdir('..')

from utils.GAN_utils import (build_generator, build_discriminator, GAN, 
                             generate_fake_data,
                             generate_real_and_fake_data
                             )

### Setup generator & discriminator configs

In [3]:
with open(os.path.join(os.getcwd(), "configs/GAN_config.yaml"), "r") as file:
    gan_config = yaml.safe_load(file)

    
generator_config = gan_config['generator']
discriminator_config = gan_config['discriminator']

### Load data - MNIST 

In [4]:
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

In [5]:
mnist_digits = np.concatenate([X_train, X_train], axis=0)
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255

### Build generator and discriminator

In [6]:
generator = build_generator(generator_config)
generator.summary()

Model: "generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
generator_input (InputLayer) [(None, 2)]               0         
_________________________________________________________________
shape_prod (Dense)           (None, 6272)              18816     
_________________________________________________________________
reshape (Reshape)            (None, 7, 7, 128)         0         
_________________________________________________________________
generator_conv_1 (Conv2DTran (None, 14, 14, 128)       262272    
_________________________________________________________________
batch_norm_generator_1 (Batc (None, 14, 14, 128)       512       
_________________________________________________________________
leaky_relu_generator_1 (Leak (None, 14, 14, 128)       0         
_________________________________________________________________
generator_conv_2 (Conv2DTran (None, 28, 28, 128)       26

In [7]:
discriminator = build_discriminator(discriminator_config)
discriminator.summary()

Model: "discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
discriminator_input (InputLa [(None, 28, 28, 1)]       0         
_________________________________________________________________
discriminator_conv_1 (Conv2D (None, 14, 14, 64)        640       
_________________________________________________________________
batch_norm_discriminator_1 ( (None, 14, 14, 64)        256       
_________________________________________________________________
leaky_relu_discriminator_1 ( (None, 14, 14, 64)        0         
_________________________________________________________________
discriminator_conv_2 (Conv2D (None, 7, 7, 64)          36928     
_________________________________________________________________
batch_norm_discriminator_2 ( (None, 7, 7, 64)          256       
_________________________________________________________________
leaky_relu_discriminator_2 ( (None, 7, 7, 64)        

In [8]:
GAN = GAN(generator, discriminator)
GAN.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
generator (Functional)       (None, 28, 28, 1)         550657    
_________________________________________________________________
discriminator (Functional)   (None, 1)                 41217     
Total params: 591,874
Trainable params: 550,145
Non-trainable params: 41,729
_________________________________________________________________


In [13]:
def train(GAN: tf.keras.models.Model, 
          generator: tf.keras.models.Model,
          generator_config: Dict[str, Union[int, List[int], List[str]]],
          discriminator: tf.keras.models.Model,
          data: np.ndarray,
          n_epochs: int = 10,
          batch_size: int = 128):
    
    n_batch = data.shape[0] // batch_size 
    half_batch_size = batch_size // 2
    
    for i in range(n_epochs):
        for j in range(n_batch):
            
            X, y = generate_real_and_fake_data(
                data=data, 
                generator=generator, 
                generator_config=generator_config, 
                n_samples=half_batch_size
            )
            discriminator_loss = discriminator.train_on_batch(X, y)
            
            
            X_GAN, y_GAN = generate_fake_data(
                generator=generator, 
                generator_config=generator_config,
                n_samples=half_batch_size, 
                images=False, # generate X as noise not images
                inverse_labels=True 
            )
            GAN_loss = GAN.train_on_batch(X_GAN, y_GAN)
            
            if j % 10 == 0:
                print(f'EPOCH: {i} | BATCH: {j}')
                print('=' * 50)
                print(f'discriminator_loss:', discriminator_loss)
                print(f'GAN_loss          :', GAN_loss)
                print(f'-' * 50)

In [None]:
train(
    GAN=GAN,
    generator=generator,
    generator_config=generator_config,
    discriminator=discriminator,
    data=mnist_digits,
    n_epochs=10,
    batch_size=128
)

EPOCH: 0 | BATCH: 0
discriminator_loss: 0.000513640814460814
GAN_loss          : 0.3672392666339874
------------------------------
EPOCH: 0 | BATCH: 1
discriminator_loss: 0.0005492715863510966
GAN_loss          : 0.36262381076812744
------------------------------
EPOCH: 0 | BATCH: 2
discriminator_loss: 0.0006372646894305944
GAN_loss          : 0.3615275025367737
------------------------------
EPOCH: 0 | BATCH: 3
discriminator_loss: 0.0006209620041772723
GAN_loss          : 0.36155688762664795
------------------------------
EPOCH: 0 | BATCH: 4
discriminator_loss: 0.0007829570677131414
GAN_loss          : 0.35998380184173584
------------------------------
EPOCH: 0 | BATCH: 5
discriminator_loss: 0.0007511109579354525
GAN_loss          : 0.36034584045410156
------------------------------
EPOCH: 0 | BATCH: 6
discriminator_loss: 0.0007631541229784489
GAN_loss          : 0.3597460985183716
------------------------------
EPOCH: 0 | BATCH: 7
discriminator_loss: 0.0013501006178557873
GAN_loss   

GAN_loss          : 0.9284918904304504
------------------------------
EPOCH: 0 | BATCH: 51
discriminator_loss: 0.00359773775562644
GAN_loss          : 0.9660500288009644
------------------------------
EPOCH: 0 | BATCH: 52
discriminator_loss: 0.003264215774834156
GAN_loss          : 1.0032033920288086
------------------------------
EPOCH: 0 | BATCH: 53
discriminator_loss: 0.002974572591483593
GAN_loss          : 1.0415544509887695
------------------------------
EPOCH: 0 | BATCH: 54
discriminator_loss: 0.004165462218225002
GAN_loss          : 1.0871593952178955
------------------------------
EPOCH: 0 | BATCH: 55
discriminator_loss: 0.004812253639101982
GAN_loss          : 1.1451382637023926
------------------------------
EPOCH: 0 | BATCH: 56
discriminator_loss: 0.004340768326073885
GAN_loss          : 1.1991090774536133
------------------------------
EPOCH: 0 | BATCH: 57
discriminator_loss: 0.007629500236362219
GAN_loss          : 1.2612829208374023
------------------------------
EPOCH: 

GAN_loss          : 6.756747722625732
------------------------------
EPOCH: 0 | BATCH: 102
discriminator_loss: 0.693111777305603
GAN_loss          : 6.227655410766602
------------------------------
EPOCH: 0 | BATCH: 103
discriminator_loss: 0.8583762645721436
GAN_loss          : 6.281111717224121
------------------------------
EPOCH: 0 | BATCH: 104
discriminator_loss: 0.31516730785369873
GAN_loss          : 6.527483940124512
------------------------------
EPOCH: 0 | BATCH: 105
discriminator_loss: 0.19416555762290955
GAN_loss          : 6.476814270019531
------------------------------
EPOCH: 0 | BATCH: 106
discriminator_loss: 0.2857683002948761
GAN_loss          : 6.181631565093994
------------------------------
EPOCH: 0 | BATCH: 107
discriminator_loss: 0.4462594985961914
GAN_loss          : 5.533725738525391
------------------------------
EPOCH: 0 | BATCH: 108
discriminator_loss: 0.5641763210296631
GAN_loss          : 4.720935821533203
------------------------------
EPOCH: 0 | BATCH: 10

GAN_loss          : 2.3658995628356934
------------------------------
EPOCH: 0 | BATCH: 153
discriminator_loss: 0.29191818833351135
GAN_loss          : 2.669508934020996
------------------------------
EPOCH: 0 | BATCH: 154
discriminator_loss: 0.2550147473812103
GAN_loss          : 3.1175715923309326
------------------------------
EPOCH: 0 | BATCH: 155
discriminator_loss: 0.19183354079723358
GAN_loss          : 3.788466215133667
------------------------------
EPOCH: 0 | BATCH: 156
discriminator_loss: 0.1550951600074768
GAN_loss          : 4.113454341888428
------------------------------
EPOCH: 0 | BATCH: 157
discriminator_loss: 0.1050468385219574
GAN_loss          : 4.459429740905762
------------------------------
EPOCH: 0 | BATCH: 158
discriminator_loss: 0.030194202437996864
GAN_loss          : 4.889671325683594
------------------------------
EPOCH: 0 | BATCH: 159
discriminator_loss: 0.03874582797288895
GAN_loss          : 5.096742153167725
------------------------------
EPOCH: 0 | BAT

GAN_loss          : 2.411562919616699
------------------------------
EPOCH: 0 | BATCH: 204
discriminator_loss: 0.11256970465183258
GAN_loss          : 2.3928565979003906
------------------------------
EPOCH: 0 | BATCH: 205
discriminator_loss: 0.1251736581325531
GAN_loss          : 2.4708762168884277
------------------------------
EPOCH: 0 | BATCH: 206
discriminator_loss: 0.08180682361125946
GAN_loss          : 2.203498363494873
------------------------------
EPOCH: 0 | BATCH: 207
discriminator_loss: 0.21291695535182953
GAN_loss          : 1.8701505661010742
------------------------------
EPOCH: 0 | BATCH: 208
discriminator_loss: 0.3274056911468506
GAN_loss          : 1.663158893585205
------------------------------
EPOCH: 0 | BATCH: 209
discriminator_loss: 0.24623706936836243
GAN_loss          : 1.343416452407837
------------------------------
EPOCH: 0 | BATCH: 210
discriminator_loss: 0.5530053973197937
GAN_loss          : 1.325728416442871
------------------------------
EPOCH: 0 | BAT

GAN_loss          : 2.4298431873321533
------------------------------
EPOCH: 0 | BATCH: 255
discriminator_loss: 0.01842178776860237
GAN_loss          : 2.6946444511413574
------------------------------
EPOCH: 0 | BATCH: 256
discriminator_loss: 0.03790779411792755
GAN_loss          : 2.797248125076294
------------------------------
EPOCH: 0 | BATCH: 257
discriminator_loss: 0.009808365255594254
GAN_loss          : 2.3597536087036133
------------------------------
EPOCH: 0 | BATCH: 258
discriminator_loss: 0.059238631278276443
GAN_loss          : 2.253755807876587
------------------------------
EPOCH: 0 | BATCH: 259
discriminator_loss: 0.0632544606924057
GAN_loss          : 1.8293488025665283
------------------------------
EPOCH: 0 | BATCH: 260
discriminator_loss: 0.09730807691812515
GAN_loss          : 2.0942935943603516
------------------------------
EPOCH: 0 | BATCH: 261
discriminator_loss: 0.0635431781411171
GAN_loss          : 1.6033802032470703
------------------------------
EPOCH: 0

GAN_loss          : 5.385037899017334
------------------------------
EPOCH: 0 | BATCH: 306
discriminator_loss: 0.3893906772136688
GAN_loss          : 6.3944993019104
------------------------------
EPOCH: 0 | BATCH: 307
discriminator_loss: 0.2148190438747406
GAN_loss          : 6.7387285232543945
------------------------------
EPOCH: 0 | BATCH: 308
discriminator_loss: 0.3492506444454193
GAN_loss          : 7.277415752410889
------------------------------
EPOCH: 0 | BATCH: 309
discriminator_loss: 0.16315148770809174
GAN_loss          : 7.433344841003418
------------------------------
EPOCH: 0 | BATCH: 310
discriminator_loss: 0.15963992476463318
GAN_loss          : 7.5121002197265625
------------------------------
EPOCH: 0 | BATCH: 311
discriminator_loss: 0.06430923938751221
GAN_loss          : 7.464145660400391
------------------------------
EPOCH: 0 | BATCH: 312
discriminator_loss: 0.07358716428279877
GAN_loss          : 7.627452850341797
------------------------------
EPOCH: 0 | BATCH:

GAN_loss          : 4.913134574890137
------------------------------
EPOCH: 0 | BATCH: 357
discriminator_loss: 0.02555771730840206
GAN_loss          : 5.178258895874023
------------------------------
EPOCH: 0 | BATCH: 358
discriminator_loss: 0.007220904342830181
GAN_loss          : 5.429058074951172
------------------------------
EPOCH: 0 | BATCH: 359
discriminator_loss: 0.09196687489748001
GAN_loss          : 5.570276260375977
------------------------------
EPOCH: 0 | BATCH: 360
discriminator_loss: 0.0011649802327156067
GAN_loss          : 5.747781753540039
------------------------------
EPOCH: 0 | BATCH: 361
discriminator_loss: 0.0017441394738852978
GAN_loss          : 5.861375331878662
------------------------------
EPOCH: 0 | BATCH: 362
discriminator_loss: 0.08554726839065552
GAN_loss          : 6.020042896270752
------------------------------
EPOCH: 0 | BATCH: 363
discriminator_loss: 0.011177487671375275
GAN_loss          : 6.073498725891113
------------------------------
EPOCH: 0

GAN_loss          : 3.9477038383483887
------------------------------
EPOCH: 0 | BATCH: 408
discriminator_loss: 1.4414244890213013
GAN_loss          : 4.2975664138793945
------------------------------
EPOCH: 0 | BATCH: 409
discriminator_loss: 1.7642590999603271
GAN_loss          : 4.108818531036377
------------------------------
EPOCH: 0 | BATCH: 410
discriminator_loss: 2.153282642364502
GAN_loss          : 4.737420558929443
------------------------------
EPOCH: 0 | BATCH: 411
discriminator_loss: 2.50465726852417
GAN_loss          : 4.184560775756836
------------------------------
EPOCH: 0 | BATCH: 412
discriminator_loss: 1.8973312377929688
GAN_loss          : 4.5171709060668945
------------------------------
EPOCH: 0 | BATCH: 413
discriminator_loss: 2.305203437805176
GAN_loss          : 3.523827075958252
------------------------------
EPOCH: 0 | BATCH: 414
discriminator_loss: 2.0305604934692383
GAN_loss          : 2.76975679397583
------------------------------
EPOCH: 0 | BATCH: 415
d

GAN_loss          : 9.907087326049805
------------------------------
EPOCH: 0 | BATCH: 459
discriminator_loss: 0.009302593767642975
GAN_loss          : 7.299469947814941
------------------------------
EPOCH: 0 | BATCH: 460
discriminator_loss: 0.13364851474761963
GAN_loss          : 8.49979019165039
------------------------------
EPOCH: 0 | BATCH: 461
discriminator_loss: 0.007267548702657223
GAN_loss          : 6.64262056350708
------------------------------
EPOCH: 0 | BATCH: 462
discriminator_loss: 0.01741323620080948
GAN_loss          : 7.769357681274414
------------------------------
EPOCH: 0 | BATCH: 463
discriminator_loss: 0.017750345170497894
GAN_loss          : 5.148987293243408
------------------------------
EPOCH: 0 | BATCH: 464
discriminator_loss: 0.036934610456228256
GAN_loss          : 4.028539657592773
------------------------------
EPOCH: 0 | BATCH: 465
discriminator_loss: 0.0155279990285635
GAN_loss          : 2.945741891860962
------------------------------
EPOCH: 0 | BA

GAN_loss          : 0.45019206404685974
------------------------------
EPOCH: 0 | BATCH: 510
discriminator_loss: 0.022361094132065773
GAN_loss          : 0.3761904239654541
------------------------------
EPOCH: 0 | BATCH: 511
discriminator_loss: 0.00848347321152687
GAN_loss          : 0.29053306579589844
------------------------------
EPOCH: 0 | BATCH: 512
discriminator_loss: 0.020600100979208946
GAN_loss          : 0.227712482213974
------------------------------
EPOCH: 0 | BATCH: 513
discriminator_loss: 0.010991201736032963
GAN_loss          : 0.1509266495704651
------------------------------
EPOCH: 0 | BATCH: 514
discriminator_loss: 0.05591048300266266
GAN_loss          : 0.16664838790893555
------------------------------
EPOCH: 0 | BATCH: 515
discriminator_loss: 0.027096189558506012
GAN_loss          : 0.1228993833065033
------------------------------
EPOCH: 0 | BATCH: 516
discriminator_loss: 0.044765353202819824
GAN_loss          : 0.15465202927589417
-----------------------------

GAN_loss          : 2.500335216522217
------------------------------
EPOCH: 0 | BATCH: 561
discriminator_loss: 0.12682963907718658
GAN_loss          : 2.3618526458740234
------------------------------
EPOCH: 0 | BATCH: 562
discriminator_loss: 0.18583667278289795
GAN_loss          : 2.1690688133239746
------------------------------
EPOCH: 0 | BATCH: 563
discriminator_loss: 0.2088087797164917
GAN_loss          : 2.229311227798462
------------------------------
EPOCH: 0 | BATCH: 564
discriminator_loss: 0.4061065912246704
GAN_loss          : 2.5788183212280273
------------------------------
EPOCH: 0 | BATCH: 565
discriminator_loss: 0.2310255765914917
GAN_loss          : 2.725818157196045
------------------------------
EPOCH: 0 | BATCH: 566
discriminator_loss: 0.13570132851600647
GAN_loss          : 1.8475337028503418
------------------------------
EPOCH: 0 | BATCH: 567
discriminator_loss: 0.20981459319591522
GAN_loss          : 1.726766586303711
------------------------------
EPOCH: 0 | BA

GAN_loss          : 0.12225286662578583
------------------------------
EPOCH: 0 | BATCH: 612
discriminator_loss: 0.6880128979682922
GAN_loss          : 0.12436283379793167
------------------------------
EPOCH: 0 | BATCH: 613
discriminator_loss: 0.6367106437683105
GAN_loss          : 0.3376091718673706
------------------------------
EPOCH: 0 | BATCH: 614
discriminator_loss: 0.408172607421875
GAN_loss          : 0.31844139099121094
------------------------------
EPOCH: 0 | BATCH: 615
discriminator_loss: 0.47977930307388306
GAN_loss          : 0.41200321912765503
------------------------------
EPOCH: 0 | BATCH: 616
discriminator_loss: 0.3116954565048218
GAN_loss          : 0.5309925675392151
------------------------------
EPOCH: 0 | BATCH: 617
discriminator_loss: 0.23028264939785004
GAN_loss          : 0.7744518518447876
------------------------------
EPOCH: 0 | BATCH: 618
discriminator_loss: 0.4403434991836548
GAN_loss          : 0.5465670824050903
------------------------------
EPOCH: 0

GAN_loss          : 9.41421127319336
------------------------------
EPOCH: 0 | BATCH: 663
discriminator_loss: 0.019695289433002472
GAN_loss          : 9.153778076171875
------------------------------
EPOCH: 0 | BATCH: 664
discriminator_loss: 0.025149358436465263
GAN_loss          : 9.60955810546875
------------------------------
EPOCH: 0 | BATCH: 665
discriminator_loss: 0.02895212545990944
GAN_loss          : 10.404840469360352
------------------------------
EPOCH: 0 | BATCH: 666
discriminator_loss: 0.03494136780500412
GAN_loss          : 9.626669883728027
------------------------------
EPOCH: 0 | BATCH: 667
discriminator_loss: 0.013874760828912258
GAN_loss          : 9.883683204650879
------------------------------
EPOCH: 0 | BATCH: 668
discriminator_loss: 0.007732708007097244
GAN_loss          : 10.310161590576172
------------------------------
EPOCH: 0 | BATCH: 669
discriminator_loss: 0.03744307905435562
GAN_loss          : 9.525627136230469
------------------------------
EPOCH: 0 |

GAN_loss          : 8.079345703125
------------------------------
EPOCH: 0 | BATCH: 714
discriminator_loss: 0.003307800507172942
GAN_loss          : 8.033245086669922
------------------------------
EPOCH: 0 | BATCH: 715
discriminator_loss: 0.0011122049763798714
GAN_loss          : 8.027137756347656
------------------------------
EPOCH: 0 | BATCH: 716
discriminator_loss: 0.0007072027074173093
GAN_loss          : 7.978809356689453
------------------------------
EPOCH: 0 | BATCH: 717
discriminator_loss: 0.0015873742522671819
GAN_loss          : 7.9416680335998535
------------------------------
EPOCH: 0 | BATCH: 718
discriminator_loss: 0.0007364636985585093
GAN_loss          : 7.897696018218994
------------------------------
EPOCH: 0 | BATCH: 719
discriminator_loss: 0.0009427221957594156
GAN_loss          : 7.86380672454834
------------------------------
EPOCH: 0 | BATCH: 720
discriminator_loss: 0.0020823637023568153
GAN_loss          : 7.764058589935303
------------------------------
EPOC

In [None]:
def get_latent_space(generator : tf.keras.models.Model, 
                     n: int = 30, 
                     digit_size: int = 28, 
                     scale: float = 1.5):
    """
    Function to configure latent space ready for visual inspection
    
    :param generator: A trained generator for predictions
    :param n: number of steps in a grid like manner, e.g. n=30 would produce a 30 * 30 grid of predictions
    :param digit_size: The size of each image in each grid, e.g. digit_size=28 would produce a 28 * 28 digit
    :param scale: The min and max of the normal distributions, e.g. scale=1.5 would produce a y-axis and a x-axis both
    varying from -1.5 to 1.5.
    """
    
    latent_space = np.zeros((digit_size * n, digit_size * n))
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]])
            x_decoded = generator.predict(z_sample)
            digit = x_decoded[0].reshape(digit_size, digit_size)
            
            latent_space[
                i * digit_size : (i + 1) * digit_size,
                j * digit_size : (j + 1) * digit_size,
            ] = digit
    
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    return latent_space, sample_range_x, sample_range_y, pixel_range

In [None]:
fig = go.Figure()

fig.add_trace(
    go.Heatmap(
        x=np.arange(len(sample_range_x)),
        y=np.arange(len(sample_range_y)),
        z=latent_space,
        showscale=False
    )
)


fig.update_layout(
    height=800,
    width=800,
    margin=dict(b=0, l=0, r=0, t=20),
    title=dict(
        text='MNIST Digits Represented with a Multivariate Gaussian | Generative Adversarial Network',
        font=dict(size=11),
    ),
    xaxis=dict(
        title=dict(text='z[0]'),
        tickmode='array',
        tickfont=dict(size=10),
        tickvals=pixel_range,
        ticktext=sample_range_x
    ),
    yaxis=dict(
        title=dict(text='z[1]'),
        tickmode='array',
        tickfont=dict(size=10),
        tickvals=pixel_range,
        ticktext=sample_range_y,
        autorange='reversed'
    ),
)